aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-17 20:53:44 +0000
committerAndrzej Janik <[email protected]>2021-09-17 20:53:44 +0000
commitd5a4b068dd9bf72a0e1b6448583ebad609ed72c1 (patch)
tree988aaa9772d78a0963bb54a96f93f95de15877b5
parent6ef19d65010164a7cc8408663eb189b64f44d26a (diff)
downloadZLUDA-d5a4b068dd9bf72a0e1b6448583ebad609ed72c1.tar.gz
ZLUDA-d5a4b068dd9bf72a0e1b6448583ebad609ed72c1.zip
Redo handling of sregs
-rw-r--r--ptx/lib/zluda_ptx_impl.bcbin31940 -> 33884 bytes
-rw-r--r--ptx/lib/zluda_ptx_impl.cl16
-rw-r--r--ptx/src/test/spirv_run/lanemask_lt.spvtxt69
-rw-r--r--ptx/src/test/spirv_run/ntid.spvtxt67
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt134
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt138
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt154
-rw-r--r--ptx/src/translate.rs379
8 files changed, 511 insertions, 446 deletions
diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index 7aa12c8..e2f956d 100644
--- a/ptx/lib/zluda_ptx_impl.bc
+++ b/ptx/lib/zluda_ptx_impl.bc
Binary files differ
diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl
index d439795..0870fb5 100644
--- a/ptx/lib/zluda_ptx_impl.cl
+++ b/ptx/lib/zluda_ptx_impl.cl
@@ -297,6 +297,22 @@ atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_a
return (uint)__builtin_amdgcn_uicmp(1, 0, 33);
}
+ uint FUNC(sreg_tid)(uchar dim) {
+ return (uint)get_local_id(dim);
+ }
+
+ uint FUNC(sreg_ntid)(uchar dim) {
+ return (uint)get_local_size(dim);
+ }
+
+ uint FUNC(sreg_ctaid)(uchar dim) {
+ return (uint)get_group_id(dim);
+ }
+
+ uint FUNC(sreg_nctaid)(uchar dim) {
+ return (uint)get_num_groups(dim);
+ }
+
uint FUNC(sreg_clock)() {
return (uint)__builtin_amdgcn_s_memtime();
}
diff --git a/ptx/src/test/spirv_run/lanemask_lt.spvtxt b/ptx/src/test/spirv_run/lanemask_lt.spvtxt
index 0753c95..3de53ce 100644
--- a/ptx/src/test/spirv_run/lanemask_lt.spvtxt
+++ b/ptx/src/test/spirv_run/lanemask_lt.spvtxt
@@ -7,39 +7,64 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %18 = OpExtInstImport "OpenCL.std"
+ %40 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "activemask"
+ OpEntryPoint Kernel %1 "lanemask_lt"
OpExecutionMode %1 ContractionOff
- OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import
+ OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_lanemask_lt" Import
%void = OpTypeVoid
%uint = OpTypeInt 32 0
- %21 = OpTypeFunction %uint
+ %43 = OpTypeFunction %uint
%ulong = OpTypeInt 64 0
- %23 = OpTypeFunction %void %ulong %ulong
+ %45 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %15 = OpFunction %uint None %21
+ %uint_1 = OpConstant %uint 1
+ %11 = OpFunction %uint None %43
OpFunctionEnd
- %1 = OpFunction %void None %23
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %14 = OpLabel
+ %1 = OpFunction %void None %45
+ %13 = OpFunctionParameter %ulong
+ %14 = OpFunctionParameter %ulong
+ %38 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_uint Function
- OpStore %2 %6
- OpStore %3 %7
- %8 = OpLoad %ulong %3 Aligned 8
- OpStore %4 %8
- %9 = OpFunctionCall %uint %15
- OpStore %5 %9
- %10 = OpLoad %ulong %4
- %11 = OpLoad %uint %5
- %12 = OpConvertUToPtr %_ptr_Generic_uint %10
- %13 = OpCopyObject %uint %11
- OpStore %12 %13 Aligned 4
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ %8 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %13
+ OpStore %3 %14
+ %15 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %15
+ %16 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %16
+ %18 = OpLoad %ulong %4
+ %29 = OpConvertUToPtr %_ptr_Generic_uint %18
+ %28 = OpLoad %uint %29 Aligned 4
+ %17 = OpCopyObject %uint %28
+ OpStore %6 %17
+ %20 = OpLoad %uint %6
+ %31 = OpCopyObject %uint %20
+ %30 = OpIAdd %uint %31 %uint_1
+ %19 = OpCopyObject %uint %30
+ OpStore %7 %19
+ %10 = OpFunctionCall %uint %11
+ %32 = OpCopyObject %uint %10
+ %21 = OpCopyObject %uint %32
+ OpStore %8 %21
+ %23 = OpLoad %uint %7
+ %24 = OpLoad %uint %8
+ %34 = OpCopyObject %uint %23
+ %35 = OpCopyObject %uint %24
+ %33 = OpIAdd %uint %34 %35
+ %22 = OpCopyObject %uint %33
+ OpStore %7 %22
+ %25 = OpLoad %ulong %5
+ %26 = OpLoad %uint %7
+ %36 = OpConvertUToPtr %_ptr_Generic_uint %25
+ %37 = OpCopyObject %uint %26
+ OpStore %36 %37 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt
index e5f343c..6754ce4 100644
--- a/ptx/src/test/spirv_run/ntid.spvtxt
+++ b/ptx/src/test/spirv_run/ntid.spvtxt
@@ -7,55 +7,54 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %31 = OpExtInstImport "OpenCL.std"
+ %30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "ntid"
OpExecutionMode %1 ContractionOff
- OpDecorate %24 LinkageAttributes "get_local_size" Import
+ OpDecorate %11 LinkageAttributes "__zluda_ptx_impl__sreg_ntid" Import
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
%uint = OpTypeInt 32 0
- %35 = OpTypeFunction %ulong %uint
+ %uchar = OpTypeInt 8 0
+ %34 = OpTypeFunction %uint %uchar
+ %ulong = OpTypeInt 64 0
%36 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %uint_0 = OpConstant %uint 0
- %24 = OpFunction %ulong None %35
- %26 = OpFunctionParameter %uint
+ %uchar_0 = OpConstant %uchar 0
+ %11 = OpFunction %uint None %34
+ %13 = OpFunctionParameter %uchar
OpFunctionEnd
%1 = OpFunction %void None %36
- %9 = OpFunctionParameter %ulong
- %10 = OpFunctionParameter %ulong
- %29 = OpLabel
+ %14 = OpFunctionParameter %ulong
+ %15 = OpFunctionParameter %ulong
+ %28 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_uint Function
- OpStore %2 %9
- OpStore %3 %10
- %11 = OpLoad %ulong %2 Aligned 8
- OpStore %4 %11
- %12 = OpLoad %ulong %3 Aligned 8
- OpStore %5 %12
- %14 = OpLoad %ulong %4
- %27 = OpConvertUToPtr %_ptr_Generic_uint %14
- %13 = OpLoad %uint %27 Aligned 4
- OpStore %6 %13
- %23 = OpFunctionCall %ulong %24 %uint_0
- %40 = OpBitcast %ulong %23
- %16 = OpUConvert %uint %40
- %15 = OpCopyObject %uint %16
- OpStore %7 %15
- %18 = OpLoad %uint %6
- %19 = OpLoad %uint %7
- %17 = OpIAdd %uint %18 %19
- OpStore %6 %17
- %20 = OpLoad %ulong %5
- %21 = OpLoad %uint %6
- %28 = OpConvertUToPtr %_ptr_Generic_uint %20
- OpStore %28 %21 Aligned 4
+ OpStore %2 %14
+ OpStore %3 %15
+ %16 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %16
+ %17 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %17
+ %19 = OpLoad %ulong %4
+ %26 = OpConvertUToPtr %_ptr_Generic_uint %19
+ %18 = OpLoad %uint %26 Aligned 4
+ OpStore %6 %18
+ %10 = OpFunctionCall %uint %11 %uchar_0
+ %20 = OpCopyObject %uint %10
+ OpStore %7 %20
+ %22 = OpLoad %uint %6
+ %23 = OpLoad %uint %7
+ %21 = OpIAdd %uint %22 %23
+ OpStore %6 %21
+ %24 = OpLoad %ulong %5
+ %25 = OpLoad %uint %6
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %24
+ OpStore %27 %25 Aligned 4
OpReturn
- OpFunctionEnd \ No newline at end of file
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
index b99fb50..e2d4db6 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
@@ -7,89 +7,87 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %57 = OpExtInstImport "OpenCL.std"
+ %56 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_ld_st_ntid"
OpExecutionMode %1 ContractionOff
- OpDecorate %44 LinkageAttributes "_Z12get_local_idj" Import
+ OpDecorate %12 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
%uint = OpTypeInt 32 0
- %61 = OpTypeFunction %ulong %uint
%uchar = OpTypeInt 8 0
+ %60 = OpTypeFunction %uint %uchar
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
- %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+ %62 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint_0 = OpConstant %uint 0
+ %uchar_0 = OpConstant %uchar 0
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %44 = OpFunction %ulong None %61
- %46 = OpFunctionParameter %uint
+ %12 = OpFunction %uint None %60
+ %14 = OpFunctionParameter %uchar
OpFunctionEnd
- %1 = OpFunction %void None %64
- %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %55 = OpLabel
- %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %1 = OpFunction %void None %62
+ %25 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %26 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %54 = OpLabel
+ %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_ulong Function
%8 = OpVariable %_ptr_Function_ulong Function
- OpStore %12 %20
- OpStore %13 %21
- %48 = OpBitcast %_ptr_Function_ulong %12
- %47 = OpLoad %ulong %48 Aligned 8
- %14 = OpCopyObject %ulong %47
- %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14
- OpStore %10 %22
- %50 = OpBitcast %_ptr_Function_ulong %13
- %49 = OpLoad %ulong %50 Aligned 8
- %15 = OpCopyObject %ulong %49
- %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
- OpStore %11 %23
- %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10
- %17 = OpConvertPtrToU %ulong %24
- %16 = OpCopyObject %ulong %17
- %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16
- OpStore %10 %25
- %26 = OpLoad %_ptr_CrossWorkgroup_uchar %11
- %19 = OpConvertPtrToU %ulong %26
- %18 = OpCopyObject %ulong %19
- %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18
- OpStore %11 %27
- %43 = OpFunctionCall %ulong %44 %uint_0
- %68 = OpBitcast %ulong %43
- %29 = OpUConvert %uint %68
- %28 = OpCopyObject %uint %29
- OpStore %6 %28
- %31 = OpLoad %uint %6
- %69 = OpBitcast %uint %31
- %30 = OpUConvert %ulong %69
- OpStore %7 %30
- %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10
- %34 = OpLoad %ulong %7
- %51 = OpCopyObject %ulong %34
- %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %33
+ OpStore %17 %25
+ OpStore %18 %26
+ %47 = OpBitcast %_ptr_Function_ulong %17
+ %46 = OpLoad %ulong %47 Aligned 8
+ %19 = OpCopyObject %ulong %46
+ %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19
+ OpStore %15 %27
+ %49 = OpBitcast %_ptr_Function_ulong %18
+ %48 = OpLoad %ulong %49 Aligned 8
+ %20 = OpCopyObject %ulong %48
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20
+ OpStore %16 %28
+ %29 = OpLoad %_ptr_CrossWorkgroup_uchar %15
+ %22 = OpConvertPtrToU %ulong %29
+ %21 = OpCopyObject %ulong %22
+ %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %21
+ OpStore %15 %30
+ %31 = OpLoad %_ptr_CrossWorkgroup_uchar %16
+ %24 = OpConvertPtrToU %ulong %31
+ %23 = OpCopyObject %ulong %24
+ %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23
+ OpStore %16 %32
+ %11 = OpFunctionCall %uint %12 %uchar_0
+ %33 = OpCopyObject %uint %11
+ OpStore %6 %33
+ %35 = OpLoad %uint %6
+ %67 = OpBitcast %uint %35
+ %34 = OpUConvert %ulong %67
+ OpStore %7 %34
+ %37 = OpLoad %_ptr_CrossWorkgroup_uchar %15
+ %38 = OpLoad %ulong %7
+ %50 = OpCopyObject %ulong %38
+ %68 = OpBitcast %_ptr_CrossWorkgroup_uchar %37
+ %69 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %68 %50
+ %36 = OpBitcast %_ptr_CrossWorkgroup_uchar %69
+ OpStore %15 %36
+ %40 = OpLoad %_ptr_CrossWorkgroup_uchar %16
+ %41 = OpLoad %ulong %7
+ %51 = OpCopyObject %ulong %41
+ %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %40
%71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %51
- %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %71
- OpStore %10 %32
- %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11
- %37 = OpLoad %ulong %7
- %52 = OpCopyObject %ulong %37
- %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %36
- %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %52
- %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %73
- OpStore %11 %35
- %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10
- %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %39
- %38 = OpLoad %ulong %53 Aligned 8
- OpStore %8 %38
- %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11
- %41 = OpLoad %ulong %8
- %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %40
- OpStore %54 %41 Aligned 8
+ %39 = OpBitcast %_ptr_CrossWorkgroup_uchar %71
+ OpStore %16 %39
+ %43 = OpLoad %_ptr_CrossWorkgroup_uchar %15
+ %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %43
+ %42 = OpLoad %ulong %52 Aligned 8
+ OpStore %8 %42
+ %44 = OpLoad %_ptr_CrossWorkgroup_uchar %16
+ %45 = OpLoad %ulong %8
+ %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %44
+ OpStore %53 %45 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
index 0239632..5da0ef3 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
@@ -7,93 +7,91 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %65 = OpExtInstImport "OpenCL.std"
+ %64 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain"
OpExecutionMode %1 ContractionOff
- OpDecorate %52 LinkageAttributes "_Z12get_local_idj" Import
+ OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
%uint = OpTypeInt 32 0
- %69 = OpTypeFunction %ulong %uint
%uchar = OpTypeInt 8 0
+ %68 = OpTypeFunction %uint %uchar
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
- %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+ %70 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint_0 = OpConstant %uint 0
+ %uchar_0 = OpConstant %uchar 0
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %52 = OpFunction %ulong None %69
- %54 = OpFunctionParameter %uint
+ %16 = OpFunction %uint None %68
+ %18 = OpFunctionParameter %uchar
OpFunctionEnd
- %1 = OpFunction %void None %72
- %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %63 = OpLabel
+ %1 = OpFunction %void None %70
+ %33 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %34 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %62 = OpLabel
+ %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%10 = OpVariable %_ptr_Function_uint Function
%11 = OpVariable %_ptr_Function_ulong Function
%12 = OpVariable %_ptr_Function_ulong Function
- OpStore %20 %28
- OpStore %21 %29
- %56 = OpBitcast %_ptr_Function_ulong %20
- %55 = OpLoad %ulong %56 Aligned 8
- %22 = OpCopyObject %ulong %55
- %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
- OpStore %14 %30
- %58 = OpBitcast %_ptr_Function_ulong %21
- %57 = OpLoad %ulong %58 Aligned 8
- %23 = OpCopyObject %ulong %57
- %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23
- OpStore %17 %31
- %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14
- %25 = OpConvertPtrToU %ulong %32
- %24 = OpCopyObject %ulong %25
- %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24
- OpStore %15 %33
- %34 = OpLoad %_ptr_CrossWorkgroup_uchar %17
- %27 = OpConvertPtrToU %ulong %34
- %26 = OpCopyObject %ulong %27
- %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
- OpStore %18 %35
- %51 = OpFunctionCall %ulong %52 %uint_0
- %76 = OpBitcast %ulong %51
- %37 = OpUConvert %uint %76
- %36 = OpCopyObject %uint %37
- OpStore %10 %36
- %39 = OpLoad %uint %10
- %77 = OpBitcast %uint %39
- %38 = OpUConvert %ulong %77
- OpStore %11 %38
- %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15
- %42 = OpLoad %ulong %11
- %59 = OpCopyObject %ulong %42
- %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %41
+ OpStore %25 %33
+ OpStore %26 %34
+ %55 = OpBitcast %_ptr_Function_ulong %25
+ %54 = OpLoad %ulong %55 Aligned 8
+ %27 = OpCopyObject %ulong %54
+ %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27
+ OpStore %19 %35
+ %57 = OpBitcast %_ptr_Function_ulong %26
+ %56 = OpLoad %ulong %57 Aligned 8
+ %28 = OpCopyObject %ulong %56
+ %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28
+ OpStore %22 %36
+ %37 = OpLoad %_ptr_CrossWorkgroup_uchar %19
+ %30 = OpConvertPtrToU %ulong %37
+ %29 = OpCopyObject %ulong %30
+ %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29
+ OpStore %20 %38
+ %39 = OpLoad %_ptr_CrossWorkgroup_uchar %22
+ %32 = OpConvertPtrToU %ulong %39
+ %31 = OpCopyObject %ulong %32
+ %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31
+ OpStore %23 %40
+ %15 = OpFunctionCall %uint %16 %uchar_0
+ %41 = OpCopyObject %uint %15
+ OpStore %10 %41
+ %43 = OpLoad %uint %10
+ %75 = OpBitcast %uint %43
+ %42 = OpUConvert %ulong %75
+ OpStore %11 %42
+ %45 = OpLoad %_ptr_CrossWorkgroup_uchar %20
+ %46 = OpLoad %ulong %11
+ %58 = OpCopyObject %ulong %46
+ %76 = OpBitcast %_ptr_CrossWorkgroup_uchar %45
+ %77 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %76 %58
+ %44 = OpBitcast %_ptr_CrossWorkgroup_uchar %77
+ OpStore %21 %44
+ %48 = OpLoad %_ptr_CrossWorkgroup_uchar %23
+ %49 = OpLoad %ulong %11
+ %59 = OpCopyObject %ulong %49
+ %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %48
%79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %59
- %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %79
- OpStore %16 %40
- %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18
- %45 = OpLoad %ulong %11
- %60 = OpCopyObject %ulong %45
- %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
- %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %60
- %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %81
- OpStore %19 %43
- %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16
- %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %47
- %46 = OpLoad %ulong %61 Aligned 8
- OpStore %12 %46
- %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19
- %49 = OpLoad %ulong %12
- %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %48
- OpStore %62 %49 Aligned 8
+ %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %79
+ OpStore %24 %47
+ %51 = OpLoad %_ptr_CrossWorkgroup_uchar %21
+ %60 = OpBitcast %_ptr_CrossWorkgroup_ulong %51
+ %50 = OpLoad %ulong %60 Aligned 8
+ OpStore %12 %50
+ %52 = OpLoad %_ptr_CrossWorkgroup_uchar %24
+ %53 = OpLoad %ulong %12
+ %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %52
+ OpStore %61 %53 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt
index 987e205..0ef5d28 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt
@@ -7,103 +7,101 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %71 = OpExtInstImport "OpenCL.std"
+ %70 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub"
OpExecutionMode %1 ContractionOff
- OpDecorate %54 LinkageAttributes "_Z12get_local_idj" Import
+ OpDecorate %16 LinkageAttributes "__zluda_ptx_impl__sreg_tid" Import
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
%uint = OpTypeInt 32 0
- %75 = OpTypeFunction %ulong %uint
%uchar = OpTypeInt 8 0
+ %74 = OpTypeFunction %uint %uchar
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
- %78 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+ %76 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint_0 = OpConstant %uint 0
+ %uchar_0 = OpConstant %uchar 0
%ulong_0 = OpConstant %ulong 0
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%ulong_0_0 = OpConstant %ulong 0
- %54 = OpFunction %ulong None %75
- %56 = OpFunctionParameter %uint
+ %16 = OpFunction %uint None %74
+ %18 = OpFunctionParameter %uchar
OpFunctionEnd
- %1 = OpFunction %void None %78
- %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %31 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %69 = OpLabel
+ %1 = OpFunction %void None %76
+ %35 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %36 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %68 = OpLabel
+ %25 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %26 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %22 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %23 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %24 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%10 = OpVariable %_ptr_Function_uint Function
%11 = OpVariable %_ptr_Function_ulong Function
%12 = OpVariable %_ptr_Function_ulong Function
- OpStore %20 %30
- OpStore %21 %31
- %62 = OpBitcast %_ptr_Function_ulong %20
- %61 = OpLoad %ulong %62 Aligned 8
- %22 = OpCopyObject %ulong %61
- %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
- OpStore %14 %32
- %64 = OpBitcast %_ptr_Function_ulong %21
- %63 = OpLoad %ulong %64 Aligned 8
- %23 = OpCopyObject %ulong %63
- %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23
- OpStore %17 %33
- %34 = OpLoad %_ptr_CrossWorkgroup_uchar %14
- %25 = OpConvertPtrToU %ulong %34
- %24 = OpCopyObject %ulong %25
- %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24
- OpStore %15 %35
- %36 = OpLoad %_ptr_CrossWorkgroup_uchar %17
- %27 = OpConvertPtrToU %ulong %36
- %26 = OpCopyObject %ulong %27
- %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
- OpStore %18 %37
- %53 = OpFunctionCall %ulong %54 %uint_0
- %82 = OpBitcast %ulong %53
- %39 = OpUConvert %uint %82
- %38 = OpCopyObject %uint %39
- OpStore %10 %38
- %41 = OpLoad %uint %10
- %83 = OpBitcast %uint %41
- %40 = OpUConvert %ulong %83
- OpStore %11 %40
- %42 = OpLoad %ulong %11
- %65 = OpCopyObject %ulong %42
- %28 = OpSNegate %ulong %65
- %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15
- %84 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
- %85 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %84 %28
- %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %85
- OpStore %16 %43
- %45 = OpLoad %ulong %11
- %66 = OpCopyObject %ulong %45
- %29 = OpSNegate %ulong %66
- %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18
- %86 = OpBitcast %_ptr_CrossWorkgroup_uchar %47
- %87 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %86 %29
- %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %87
- OpStore %19 %46
- %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16
- %67 = OpBitcast %_ptr_CrossWorkgroup_ulong %49
+ OpStore %25 %35
+ OpStore %26 %36
+ %61 = OpBitcast %_ptr_Function_ulong %25
+ %60 = OpLoad %ulong %61 Aligned 8
+ %27 = OpCopyObject %ulong %60
+ %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %27
+ OpStore %19 %37
+ %63 = OpBitcast %_ptr_Function_ulong %26
+ %62 = OpLoad %ulong %63 Aligned 8
+ %28 = OpCopyObject %ulong %62
+ %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %28
+ OpStore %22 %38
+ %39 = OpLoad %_ptr_CrossWorkgroup_uchar %19
+ %30 = OpConvertPtrToU %ulong %39
+ %29 = OpCopyObject %ulong %30
+ %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %29
+ OpStore %20 %40
+ %41 = OpLoad %_ptr_CrossWorkgroup_uchar %22
+ %32 = OpConvertPtrToU %ulong %41
+ %31 = OpCopyObject %ulong %32
+ %42 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %31
+ OpStore %23 %42
+ %15 = OpFunctionCall %uint %16 %uchar_0
+ %43 = OpCopyObject %uint %15
+ OpStore %10 %43
+ %45 = OpLoad %uint %10
+ %81 = OpBitcast %uint %45
+ %44 = OpUConvert %ulong %81
+ OpStore %11 %44
+ %46 = OpLoad %ulong %11
+ %64 = OpCopyObject %ulong %46
+ %33 = OpSNegate %ulong %64
+ %48 = OpLoad %_ptr_CrossWorkgroup_uchar %20
+ %82 = OpBitcast %_ptr_CrossWorkgroup_uchar %48
+ %83 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %82 %33
+ %47 = OpBitcast %_ptr_CrossWorkgroup_uchar %83
+ OpStore %21 %47
+ %49 = OpLoad %ulong %11
+ %65 = OpCopyObject %ulong %49
+ %34 = OpSNegate %ulong %65
+ %51 = OpLoad %_ptr_CrossWorkgroup_uchar %23
+ %84 = OpBitcast %_ptr_CrossWorkgroup_uchar %51
+ %85 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %84 %34
+ %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %85
+ OpStore %24 %50
+ %53 = OpLoad %_ptr_CrossWorkgroup_uchar %21
+ %66 = OpBitcast %_ptr_CrossWorkgroup_ulong %53
+ %87 = OpBitcast %_ptr_CrossWorkgroup_uchar %66
+ %88 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %87 %ulong_0
+ %57 = OpBitcast %_ptr_CrossWorkgroup_ulong %88
+ %52 = OpLoad %ulong %57 Aligned 8
+ OpStore %12 %52
+ %54 = OpLoad %_ptr_CrossWorkgroup_uchar %24
+ %55 = OpLoad %ulong %12
+ %67 = OpBitcast %_ptr_CrossWorkgroup_ulong %54
%89 = OpBitcast %_ptr_CrossWorkgroup_uchar %67
- %90 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %89 %ulong_0
- %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %90
- %48 = OpLoad %ulong %58 Aligned 8
- OpStore %12 %48
- %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19
- %51 = OpLoad %ulong %12
- %68 = OpBitcast %_ptr_CrossWorkgroup_ulong %50
- %91 = OpBitcast %_ptr_CrossWorkgroup_uchar %68
- %92 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %91 %ulong_0_0
- %60 = OpBitcast %_ptr_CrossWorkgroup_ulong %92
- OpStore %60 %51 Aligned 8
+ %90 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %89 %ulong_0_0
+ %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %90
+ OpStore %59 %55 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 15dcdd1..8135c51 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -10,8 +10,6 @@ use rspirv::binary::{Assemble, Disassemble};
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
-const ZLUDA_PTX_PREFIX_SREG_CLOCK: &'static str = "__zluda_ptx_impl__sreg_clock";
-const ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT: &'static str = "__zluda_ptx_impl__sreg_lanemask_lt";
quick_error! {
#[derive(Debug)]
@@ -426,8 +424,8 @@ pub struct KernelInfo {
pub uses_shared_mem: bool,
}
-pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
- let mut id_defs = GlobalStringIdResolver::new(1);
+pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut id_defs = GlobalStringIdResolver::<'input>::new(1);
let mut ptx_impl_imports = HashMap::new();
let directives = ast
.directives
@@ -1135,9 +1133,9 @@ fn emit_memory_model(builder: &mut dr::Builder) {
);
}
-fn translate_directive<'input>(
- id_defs: &mut GlobalStringIdResolver<'input>,
- ptx_impl_imports: &mut HashMap<String, Directive<'input>>,
+fn translate_directive<'input, 'a>(
+ id_defs: &'a mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
@@ -1157,11 +1155,11 @@ fn translate_directive<'input>(
})
}
-fn translate_function<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- ptx_impl_imports: &mut HashMap<String, Directive<'a>>,
- f: ast::ParsedFunction<'a>,
-) -> Result<Option<Function<'a>>, TranslateError> {
+fn translate_function<'input, 'a>(
+ id_defs: &'a mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ f: ast::ParsedFunction<'input>,
+) -> Result<Option<Function<'input>>, TranslateError> {
let import_as = match &f.func_directive {
ast::MethodDeclaration {
name: ast::MethodName::Func("__assertfail"),
@@ -1206,7 +1204,7 @@ fn rename_fn_params<'a, 'b>(
}
fn to_ssa<'input, 'b>(
- ptx_impl_imports: &mut HashMap<String, Directive>,
+ ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
@@ -1231,6 +1229,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
+ let typed_statements =
+ fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
@@ -1238,8 +1238,6 @@ fn to_ssa<'input, 'b>(
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements =
- fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1257,90 +1255,147 @@ fn to_ssa<'input, 'b>(
})
}
-fn fix_special_registers(
- ptx_impl_imports: &mut HashMap<String, Directive>,
+fn fix_special_registers2<'a, 'b, 'input>(
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
typed_statements: Vec<TypedStatement>,
- numeric_id_defs: &mut NumericIdResolver,
+ numeric_id_defs: &'a mut NumericIdResolver<'b>,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(typed_statements.len());
+ let result = Vec::with_capacity(typed_statements.len());
+ let mut sreg_sresolver = SpecialRegisterResolver {
+ ptx_impl_imports,
+ numeric_id_defs,
+ result,
+ };
for s in typed_statements {
match s {
- Statement::LoadVar(
- details
- @
- LoadVarDetails {
- member_index: Some((_, Some(_))),
- ..
- },
- ) => {
- let index = details.member_index.unwrap().0;
- let sreg = numeric_id_defs
- .special_registers
- .get(details.arg.src)
- .ok_or_else(|| error_unreachable())?;
- let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
- let index_constant = numeric_id_defs.register_intermediate(Some((
- ast::Type::Scalar(ast::ScalarType::U32),
- ast::StateSpace::Reg,
- )));
- result.push(Statement::Constant(ConstantDefinition {
- dst: index_constant,
- typ: ast::ScalarType::U32,
- value: ast::ImmediateValue::U64(index as u64),
- }));
- let fn_result = numeric_id_defs.register_intermediate(Some((
- ast::Type::Scalar(ocl_type),
- ast::StateSpace::Reg,
- )));
- let return_arguments =
- vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
- let input_arguments = vec![(
- TypedOperand::Reg(index_constant),
- ast::Type::Scalar(ast::ScalarType::U32),
- ast::StateSpace::Reg,
- )];
- let fn_call = register_external_fn_call(
- numeric_id_defs,
- ptx_impl_imports,
- ocl_name.to_string(),
- return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
- input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
- )?;
- result.push(Statement::Call(ResolvedCall {
- uniform: false,
- return_arguments,
- name: fn_call,
- input_arguments,
- }));
- result.push(Statement::Conversion(ImplicitConversion {
- src: fn_result,
- dst: details.arg.dst,
- from_type: ast::Type::Scalar(ocl_type),
- from_space: ast::StateSpace::Reg,
- to_type: ast::Type::Scalar(ast::ScalarType::U32),
- to_space: ast::StateSpace::Reg,
- kind: ConversionKind::Default,
- }));
+ Statement::Call(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::Instruction(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
}
- s => result.push(s),
+ Statement::Conditional(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::Conversion(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::PtrAccess(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::RepackVector(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ s @ Statement::Variable(_)
+ | s @ Statement::Label(_)
+ | s @ Statement::FunctionPointer(_) => sreg_sresolver.result.push(s),
+ _ => return Err(error_unreachable()),
}
}
- Ok(result)
+ Ok(sreg_sresolver.result)
}
-fn get_sreg_id_scalar_type(
- numeric_id_defs: &mut NumericIdResolver,
- sreg: PtxSpecialRegister,
-) -> Option<(spirv::Word, ast::ScalarType, u8)> {
- match sreg.normalized_sreg_and_type() {
- Some((normalized_sreg, typ, vec_width)) => Some((
- numeric_id_defs
- .special_registers
- .get_or_add(numeric_id_defs.current_id, normalized_sreg),
- typ,
- vec_width,
- )),
- None => None,
+struct SpecialRegisterResolver<'a, 'b, 'input> {
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ numeric_id_defs: &'a mut NumericIdResolver<'b>,
+ result: Vec<TypedStatement>,
+}
+
+impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
+ fn replace_sreg(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ vector_index: Option<u8>,
+ ) -> Result<spirv::Word, TranslateError> {
+ if let Some(sreg) = self.numeric_id_defs.special_registers.get(desc.op) {
+ if desc.is_dst {
+ return Err(TranslateError::MismatchedType);
+ }
+ let input_arguments = match (vector_index, sreg.get_function_input_type()) {
+ (Some(idx), Some(inp_type)) => {
+ if inp_type != ast::ScalarType::U8 {
+ return Err(TranslateError::Unreachable);
+ }
+ let constant = self.numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: constant,
+ typ: inp_type,
+ value: ast::ImmediateValue::U64(idx as u64),
+ }));
+ vec![(
+ TypedOperand::Reg(constant),
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )]
+ }
+ (None, None) => Vec::new(),
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let return_type = sreg.get_function_return_type();
+ let fn_result = self.numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments = vec![(
+ fn_result,
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = register_external_fn_call(
+ self.numeric_id_defs,
+ self.ptx_impl_imports,
+ ocl_fn_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ self.result.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ return_arguments,
+ name: fn_call,
+ input_arguments,
+ }));
+ Ok(fn_result)
+ } else {
+ Ok(desc.op)
+ }
+ }
+}
+
+impl<'a, 'b, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
+ for SpecialRegisterResolver<'a, 'b, 'input>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<spirv::Word, TranslateError> {
+ self.replace_sreg(desc, None)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<TypedOperand>,
+ typ: &ast::Type,
+ state_space: ast::StateSpace,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(reg) => TypedOperand::Reg(self.replace_sreg(desc.new_op(reg), None)?),
+ op @ TypedOperand::RegOffset(_, _) => op,
+ op @ TypedOperand::Imm(_) => op,
+ TypedOperand::VecMember(reg, idx) => {
+ TypedOperand::VecMember(self.replace_sreg(desc.new_op(reg), Some(idx))?, idx)
+ }
+ })
}
}
@@ -1968,22 +2023,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
- Statement::Conditional(mut bra) => {
- let generated_id = id_def.register_intermediate(Some((
- ast::Type::Scalar(ast::ScalarType::Pred),
- ast::StateSpace::Reg,
- )));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- dst: generated_id,
- src: bra.predicate,
- },
- state_space: ast::StateSpace::Reg,
- typ: ast::Type::Scalar(ast::ScalarType::Pred),
- member_index: None,
- }));
- bra.predicate = generated_id;
- result.push(Statement::Conditional(bra));
+ Statement::Conditional(bra) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, bra)?
}
Statement::Conversion(conv) => {
insert_mem_ssa_statement_default(id_def, &mut result, conv)?
@@ -1997,7 +2038,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::FunctionPointer(func_ptr) => {
insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)?
}
- s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
+ s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
+ result.push(s)
+ }
_ => return Err(error_unreachable()),
}
}
@@ -4539,6 +4582,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
match statement {
l @ Statement::Label(_) => result.push(l),
c @ Statement::Conditional(_) => result.push(c),
+ c @ Statement::Constant(..) => result.push(c),
Statement::Variable(var) => {
if !remapped_ids.contains_key(&var.name) {
result.push(Statement::Variable(var));
@@ -4791,13 +4835,9 @@ fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
- Tid64,
Ntid,
- Ntid64,
Ctaid,
- Ctaid64,
Nctaid,
- Nctaid64,
Clock,
LanemaskLt,
}
@@ -4817,71 +4857,43 @@ impl PtxSpecialRegister {
fn get_type(self) -> ast::Type {
match self {
- PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Clock => ast::Type::Scalar(ast::ScalarType::U32),
- PtxSpecialRegister::LanemaskLt => ast::Type::Scalar(ast::ScalarType::U32),
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4),
+ _ => ast::Type::Scalar(self.get_function_return_type()),
}
}
- fn get_scalar_type(self) -> ast::ScalarType {
+ fn get_function_return_type(self) -> ast::ScalarType {
match self {
- PtxSpecialRegister::Tid
- | PtxSpecialRegister::Ntid
- | PtxSpecialRegister::Ctaid
- | PtxSpecialRegister::Nctaid
- | PtxSpecialRegister::Clock
- | PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
- PtxSpecialRegister::Tid64
- | PtxSpecialRegister::Ntid64
- | PtxSpecialRegister::Ctaid64
- | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
+ PtxSpecialRegister::Tid => ast::ScalarType::U32,
+ PtxSpecialRegister::Ntid => ast::ScalarType::U32,
+ PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Clock => ast::ScalarType::U32,
+ PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
}
}
- fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
+ fn get_function_input_type(self) -> Option<ast::ScalarType> {
match self {
- PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
- ("_Z12get_local_idj", ast::ScalarType::U64)
- }
- PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
- ("_Z14get_local_sizej", ast::ScalarType::U64)
- }
- PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
- ("_Z12get_group_idj", ast::ScalarType::U64)
- }
- PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
- ("_Z14get_num_groupsj", ast::ScalarType::U64)
- }
- PtxSpecialRegister::Clock => (ZLUDA_PTX_PREFIX_SREG_CLOCK, ast::ScalarType::U32),
- PtxSpecialRegister::LanemaskLt => {
- (ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT, ast::ScalarType::U32)
- }
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
+ PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
}
}
- fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
+ fn get_unprefixed_function_name(self) -> &'static str {
match self {
- PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
- PtxSpecialRegister::Ntid => Some((PtxSpecialRegister::Ntid64, ast::ScalarType::U64, 3)),
- PtxSpecialRegister::Ctaid => {
- Some((PtxSpecialRegister::Ctaid64, ast::ScalarType::U64, 3))
- }
- PtxSpecialRegister::Nctaid => {
- Some((PtxSpecialRegister::Nctaid64, ast::ScalarType::U64, 3))
- }
- PtxSpecialRegister::Tid64
- | PtxSpecialRegister::Ntid64
- | PtxSpecialRegister::Ctaid64
- | PtxSpecialRegister::Nctaid64
- | PtxSpecialRegister::Clock => None,
- PtxSpecialRegister::LanemaskLt => None,
+ PtxSpecialRegister::Tid => "sreg_tid",
+ PtxSpecialRegister::Ntid => "sreg_ntid",
+ PtxSpecialRegister::Ctaid => "sreg_ctaid",
+ PtxSpecialRegister::Nctaid => "sreg_nctaid",
+ PtxSpecialRegister::Clock => "sreg_clock",
+ PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
}
}
}
@@ -4899,16 +4911,6 @@ impl SpecialRegistersMap {
}
}
- fn builtins<'a>(&'a self) -> impl Iterator<Item = (PtxSpecialRegister, spirv::Word)> + 'a {
- self.reg_to_id.iter().filter_map(|(sreg, id)| {
- if sreg.normalized_sreg_and_type().is_none() {
- Some((*sreg, *id))
- } else {
- None
- }
- })
- }
-
fn interface(&self) -> Vec<spirv::Word> {
return Vec::new();
/*
@@ -6416,6 +6418,35 @@ struct BrachCondition {
if_false: spirv::Word,
}
+impl<From: ArgParamsEx<Id = spirv::Word>, To: ArgParamsEx<Id = spirv::Word>> Visitable<From, To>
+ for BrachCondition
+{
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError> {
+ let predicate = visitor.id(
+ ArgumentDescriptor {
+ op: self.predicate,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
+ )?;
+ let if_true = self.if_true;
+ let if_false = self.if_false;
+ Ok(Statement::Conditional(BrachCondition {
+ predicate,
+ if_true,
+ if_false,
+ }))
+ }
+}
+
#[derive(Clone)]
struct ImplicitConversion {
src: spirv::Word,