aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-19 22:12:12 +0100
committerAndrzej Janik <[email protected]>2020-11-19 22:12:12 +0100
commitf77b653d363a3b05d34d390874cec631ff948814 (patch)
tree5e9639b77647209ee79855c9b574cf457e875520
parenteac5fbd806639c42813d06095fd3911a4664538b (diff)
downloadZLUDA-f77b653d363a3b05d34d390874cec631ff948814.tar.gz
ZLUDA-f77b653d363a3b05d34d390874cec631ff948814.zip
Implement stateless-to-stateful optimization
-rw-r--r--ptx/src/ast.rs6
-rw-r--r--ptx/src/ptx.lalrpop12
-rw-r--r--ptx/src/test/spirv_run/atom_inc.spvtxt170
-rw-r--r--ptx/src/test/spirv_run/cvta.spvtxt83
-rw-r--r--ptx/src/test/spirv_run/extern_shared_call.spvtxt84
-rw-r--r--ptx/src/test/spirv_run/mod.rs18
-rw-r--r--ptx/src/test/spirv_run/reg_local.spvtxt12
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx31
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt89
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx35
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt93
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx35
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt105
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_simple.ptx25
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt65
-rw-r--r--ptx/src/translate.rs1119
16 files changed, 1495 insertions, 487 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 5a5f6be..367f060 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -766,6 +766,8 @@ sub_type! {
LdStType {
Scalar(LdStScalarType),
Vector(LdStScalarType, u8),
+ // Used in generated code
+ Pointer(PointerType, LdStateSpace),
}
}
@@ -774,6 +776,10 @@ impl From<LdStType> for PointerType {
match t {
LdStType::Scalar(t) => PointerType::Scalar(t.into()),
LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
+ LdStType::Pointer(PointerType::Scalar(scalar_type), space) => {
+ PointerType::Pointer(scalar_type, space)
+ }
+ LdStType::Pointer(..) => unreachable!(),
}
}
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 6c231b2..d2c235a 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1237,18 +1237,18 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta
InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "cvta" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ "cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: to,
- from: ast::CvtaStateSpace::Generic,
+ to: ast::CvtaStateSpace::Generic,
+ from,
size: s
},
a)
},
- "cvta" ".to" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ "cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: ast::CvtaStateSpace::Generic,
- from: from,
+ to,
+ from: ast::CvtaStateSpace::Generic,
size: s
},
a)
diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt
index 6948cd9..fda26c5 100644
--- a/ptx/src/test/spirv_run/atom_inc.spvtxt
+++ b/ptx/src/test/spirv_run/atom_inc.spvtxt
@@ -1,89 +1,81 @@
-; SPIR-V
-; Version: 1.3
-; Generator: rspirv
-; Bound: 60
-OpCapability GenericPointer
-OpCapability Linkage
-OpCapability Addresses
-OpCapability Kernel
-OpCapability Int8
-OpCapability Int16
-OpCapability Int64
-OpCapability Float16
-OpCapability Float64
-; OpCapability FunctionFloatControlINTEL
-; OpExtension "SPV_INTEL_float_controls2"
-%49 = OpExtInstImport "OpenCL.std"
-OpMemoryModel Physical64 OpenCL
-OpEntryPoint Kernel %1 "atom_inc"
-OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import
-OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import
-%50 = OpTypeVoid
-%51 = OpTypeInt 32 0
-%52 = OpTypePointer Generic %51
-%53 = OpTypeFunction %51 %52 %51
-%54 = OpTypePointer CrossWorkgroup %51
-%55 = OpTypeFunction %51 %54 %51
-%56 = OpTypeInt 64 0
-%57 = OpTypeFunction %50 %56 %56
-%58 = OpTypePointer Function %56
-%59 = OpTypePointer Function %51
-%27 = OpConstant %51 101
-%28 = OpConstant %51 101
-%29 = OpConstant %56 4
-%31 = OpConstant %56 8
-%40 = OpFunction %51 None %53
-%42 = OpFunctionParameter %52
-%43 = OpFunctionParameter %51
-OpFunctionEnd
-%44 = OpFunction %51 None %55
-%46 = OpFunctionParameter %54
-%47 = OpFunctionParameter %51
-OpFunctionEnd
-%1 = OpFunction %50 None %57
-%9 = OpFunctionParameter %56
-%10 = OpFunctionParameter %56
-%39 = OpLabel
-%2 = OpVariable %58 Function
-%3 = OpVariable %58 Function
-%4 = OpVariable %58 Function
-%5 = OpVariable %58 Function
-%6 = OpVariable %59 Function
-%7 = OpVariable %59 Function
-%8 = OpVariable %59 Function
-OpStore %2 %9
-OpStore %3 %10
-%12 = OpLoad %56 %2
-%11 = OpCopyObject %56 %12
-OpStore %4 %11
-%14 = OpLoad %56 %3
-%13 = OpCopyObject %56 %14
-OpStore %5 %13
-%16 = OpLoad %56 %4
-%33 = OpConvertUToPtr %52 %16
-%15 = OpFunctionCall %51 %40 %33 %27
-OpStore %6 %15
-%18 = OpLoad %56 %4
-%34 = OpConvertUToPtr %54 %18
-%17 = OpFunctionCall %51 %44 %34 %28
-OpStore %7 %17
-%20 = OpLoad %56 %4
-%35 = OpConvertUToPtr %52 %20
-%19 = OpLoad %51 %35
-OpStore %8 %19
-%21 = OpLoad %56 %5
-%22 = OpLoad %51 %6
-%36 = OpConvertUToPtr %52 %21
-OpStore %36 %22
-%23 = OpLoad %56 %5
-%24 = OpLoad %51 %7
-%30 = OpIAdd %56 %23 %29
-%37 = OpConvertUToPtr %52 %30
-OpStore %37 %24
-%25 = OpLoad %56 %5
-%26 = OpLoad %51 %8
-%32 = OpIAdd %56 %25 %31
-%38 = OpConvertUToPtr %52 %32
-OpStore %38 %26
-OpReturn
-OpFunctionEnd \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %47 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "atom_inc"
+ OpDecorate %38 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import
+ OpDecorate %42 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint
+%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
+ %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint
+ %ulong = OpTypeInt 64 0
+ %55 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %uint_101 = OpConstant %uint 101
+ %uint_101_0 = OpConstant %uint 101
+ %ulong_4 = OpConstant %ulong 4
+ %ulong_8 = OpConstant %ulong 8
+ %38 = OpFunction %uint None %51
+ %40 = OpFunctionParameter %_ptr_Generic_uint
+ %41 = OpFunctionParameter %uint
+ OpFunctionEnd
+ %42 = OpFunction %uint None %53
+ %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
+ %45 = OpFunctionParameter %uint
+ OpFunctionEnd
+ %1 = OpFunction %void None %55
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %37 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ %8 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %11 = OpLoad %ulong %2
+ OpStore %4 %11
+ %12 = OpLoad %ulong %3
+ OpStore %5 %12
+ %14 = OpLoad %ulong %4
+ %31 = OpConvertUToPtr %_ptr_Generic_uint %14
+ %13 = OpFunctionCall %uint %38 %31 %uint_101
+ OpStore %6 %13
+ %16 = OpLoad %ulong %4
+ %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16
+ %15 = OpFunctionCall %uint %42 %32 %uint_101_0
+ OpStore %7 %15
+ %18 = OpLoad %ulong %4
+ %33 = OpConvertUToPtr %_ptr_Generic_uint %18
+ %17 = OpLoad %uint %33
+ OpStore %8 %17
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %uint %6
+ %34 = OpConvertUToPtr %_ptr_Generic_uint %19
+ OpStore %34 %20
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %uint %7
+ %28 = OpIAdd %ulong %21 %ulong_4
+ %35 = OpConvertUToPtr %_ptr_Generic_uint %28
+ OpStore %35 %22
+ %23 = OpLoad %ulong %5
+ %24 = OpLoad %uint %8
+ %30 = OpIAdd %ulong %23 %ulong_8
+ %36 = OpConvertUToPtr %_ptr_Generic_uint %30
+ OpStore %36 %24
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt
index cf6ff8b..143d0a5 100644
--- a/ptx/src/test/spirv_run/cvta.spvtxt
+++ b/ptx/src/test/spirv_run/cvta.spvtxt
@@ -7,48 +7,59 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %27 = OpExtInstImport "OpenCL.std"
+ %37 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvta"
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %30 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %41 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
- %1 = OpFunction %void None %30
- %7 = OpFunctionParameter %ulong
- %8 = OpFunctionParameter %ulong
- %25 = OpLabel
- %2 = OpVariable %_ptr_Function_ulong Function
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
+ %1 = OpFunction %void None %41
+ %17 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %18 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %35 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%6 = OpVariable %_ptr_Function_float Function
- OpStore %2 %7
- OpStore %3 %8
- %9 = OpLoad %ulong %2
- OpStore %4 %9
- %10 = OpLoad %ulong %3
- OpStore %5 %10
- %12 = OpLoad %ulong %4
- %20 = OpCopyObject %ulong %12
- %19 = OpCopyObject %ulong %20
- %11 = OpCopyObject %ulong %19
- OpStore %4 %11
- %14 = OpLoad %ulong %5
- %22 = OpCopyObject %ulong %14
- %21 = OpCopyObject %ulong %22
- %13 = OpCopyObject %ulong %21
- OpStore %5 %13
- %16 = OpLoad %ulong %4
- %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16
- %15 = OpLoad %float %23
- OpStore %6 %15
- %17 = OpLoad %ulong %5
- %18 = OpLoad %float %6
- %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17
- OpStore %24 %18
+ OpStore %2 %17
+ OpStore %3 %18
+ %10 = OpBitcast %_ptr_Function_ulong %2
+ %9 = OpLoad %ulong %10
+ %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9
+ OpStore %7 %19
+ %12 = OpBitcast %_ptr_Function_ulong %3
+ %11 = OpLoad %ulong %12
+ %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11
+ OpStore %8 %20
+ %21 = OpLoad %_ptr_CrossWorkgroup_uchar %7
+ %14 = OpConvertPtrToU %ulong %21
+ %30 = OpCopyObject %ulong %14
+ %29 = OpCopyObject %ulong %30
+ %13 = OpCopyObject %ulong %29
+ %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13
+ OpStore %7 %22
+ %23 = OpLoad %_ptr_CrossWorkgroup_uchar %8
+ %16 = OpConvertPtrToU %ulong %23
+ %32 = OpCopyObject %ulong %16
+ %31 = OpCopyObject %ulong %32
+ %15 = OpCopyObject %ulong %31
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
+ OpStore %8 %24
+ %26 = OpLoad %_ptr_CrossWorkgroup_uchar %7
+ %33 = OpBitcast %_ptr_CrossWorkgroup_float %26
+ %25 = OpLoad %float %33
+ OpStore %6 %25
+ %27 = OpLoad %_ptr_CrossWorkgroup_uchar %8
+ %28 = OpLoad %float %6
+ %34 = OpBitcast %_ptr_CrossWorkgroup_float %27
+ OpStore %34 %28
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
index d979193..39f8683 100644
--- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt
+++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
@@ -7,7 +7,7 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %48 = OpExtInstImport "OpenCL.std"
+ %46 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %14 "extern_shared_call" %1
OpDecorate %1 Alignment 4
@@ -18,78 +18,76 @@
%1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %55 = OpTypeFunction %void %_ptr_Workgroup_uchar
+ %53 = OpTypeFunction %void %_ptr_Workgroup_uchar
%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
%ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%ulong_2 = OpConstant %ulong 2
- %62 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
+ %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %2 = OpFunction %void None %55
- %40 = OpFunctionParameter %_ptr_Workgroup_uchar
- %56 = OpLabel
- %41 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %2 = OpFunction %void None %53
+ %38 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %54 = OpLabel
+ %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
%3 = OpVariable %_ptr_Function_ulong Function
- OpStore %41 %40
+ OpStore %39 %38
OpBranch %13
%13 = OpLabel
- %42 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41
- %5 = OpLoad %_ptr_Workgroup_uint %42
+ %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39
+ %5 = OpLoad %_ptr_Workgroup_uint %40
%11 = OpBitcast %_ptr_Workgroup_ulong %5
%4 = OpLoad %ulong %11
OpStore %3 %4
%7 = OpLoad %ulong %3
%6 = OpIAdd %ulong %7 %ulong_2
OpStore %3 %6
- %43 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41
- %8 = OpLoad %_ptr_Workgroup_uint %43
+ %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39
+ %8 = OpLoad %_ptr_Workgroup_uint %41
%9 = OpLoad %ulong %3
%12 = OpBitcast %_ptr_Workgroup_ulong %8
OpStore %12 %9
OpReturn
OpFunctionEnd
- %14 = OpFunction %void None %62
+ %14 = OpFunction %void None %60
%20 = OpFunctionParameter %ulong
%21 = OpFunctionParameter %ulong
- %44 = OpFunctionParameter %_ptr_Workgroup_uchar
- %63 = OpLabel
- %45 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %42 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %61 = OpLabel
+ %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
%15 = OpVariable %_ptr_Function_ulong Function
%16 = OpVariable %_ptr_Function_ulong Function
%17 = OpVariable %_ptr_Function_ulong Function
%18 = OpVariable %_ptr_Function_ulong Function
%19 = OpVariable %_ptr_Function_ulong Function
- OpStore %45 %44
- OpBranch %38
- %38 = OpLabel
+ OpStore %43 %42
+ OpBranch %36
+ %36 = OpLabel
OpStore %15 %20
OpStore %16 %21
- %23 = OpLoad %ulong %15
- %22 = OpCopyObject %ulong %23
+ %22 = OpLoad %ulong %15
OpStore %17 %22
- %25 = OpLoad %ulong %16
- %24 = OpCopyObject %ulong %25
- OpStore %18 %24
- %27 = OpLoad %ulong %17
- %34 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %27
- %26 = OpLoad %ulong %34
- OpStore %19 %26
- %46 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45
- %28 = OpLoad %_ptr_Workgroup_uint %46
- %29 = OpLoad %ulong %19
- %35 = OpBitcast %_ptr_Workgroup_ulong %28
- OpStore %35 %29
- %65 = OpFunctionCall %void %2 %44
- %47 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45
- %31 = OpLoad %_ptr_Workgroup_uint %47
- %36 = OpBitcast %_ptr_Workgroup_ulong %31
- %30 = OpLoad %ulong %36
- OpStore %19 %30
- %32 = OpLoad %ulong %18
- %33 = OpLoad %ulong %19
- %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %32
- OpStore %37 %33
+ %23 = OpLoad %ulong %16
+ OpStore %18 %23
+ %25 = OpLoad %ulong %17
+ %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25
+ %24 = OpLoad %ulong %32
+ OpStore %19 %24
+ %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43
+ %26 = OpLoad %_ptr_Workgroup_uint %44
+ %27 = OpLoad %ulong %19
+ %33 = OpBitcast %_ptr_Workgroup_ulong %26
+ OpStore %33 %27
+ %63 = OpFunctionCall %void %2 %42
+ %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43
+ %29 = OpLoad %_ptr_Workgroup_uint %45
+ %34 = OpBitcast %_ptr_Workgroup_ulong %29
+ %28 = OpLoad %ulong %34
+ OpStore %19 %28
+ %30 = OpLoad %ulong %18
+ %31 = OpLoad %ulong %19
+ %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30
+ OpStore %35 %31
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index bd74508..f18b15c 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -133,6 +133,10 @@ test_ptx!(
[0b11111000_11000001_00100010_10100000u32, 16u32, 8u32],
[0b11000001u32]
);
+test_ptx!(stateful_ld_st_simple, [121u64], [121u64]);
+test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]);
+test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]);
+test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]);
struct DisplayError<T: Debug> {
err: T,
@@ -292,7 +296,7 @@ fn test_spvtxt_assert<'a>(
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
let spvtxt_mod = loader.module();
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
- if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) {
+ if !is_spirv_fns_equal(&spirv_module.spirv.functions, &spvtxt_mod.functions) {
// We could simply use ptx_mod.disassemble, but SPIRV-Tools text formattinmg is so much nicer
let spv_from_ptx_binary = spirv_module.spirv.assemble();
let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
@@ -364,6 +368,18 @@ impl<T: Copy + Eq + Hash> EqMap<T> {
}
}
+fn is_spirv_fns_equal(fns1: &[Function], fns2: &[Function]) -> bool {
+ if fns1.len() != fns2.len() {
+ return false;
+ }
+ for (fn1, fn2) in fns1.iter().zip(fns2.iter()) {
+ if !is_spirv_fn_equal(fn1, fn2) {
+ return false;
+ }
+ }
+ true
+}
+
fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
let mut map = EqMap::new();
if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) {
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt
index 596cedc..5ce3689 100644
--- a/ptx/src/test/spirv_run/reg_local.spvtxt
+++ b/ptx/src/test/spirv_run/reg_local.spvtxt
@@ -22,7 +22,9 @@
%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%ulong_1 = OpConstant %ulong 1
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_0 = OpConstant %ulong 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_0_0 = OpConstant %ulong 0
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
@@ -48,12 +50,12 @@
%14 = OpLoad %ulong %7
%26 = OpCopyObject %ulong %14
%19 = OpIAdd %ulong %26 %ulong_1
- %27 = OpBitcast %_ptr_Function_ulong %4
+ %27 = OpBitcast %_ptr_Generic_ulong %4
OpStore %27 %19
- %28 = OpBitcast %_ptr_Function_ulong %4
- %45 = OpBitcast %ulong %28
- %46 = OpIAdd %ulong %45 %ulong_0
- %21 = OpBitcast %_ptr_Function_ulong %46
+ %28 = OpBitcast %_ptr_Generic_ulong %4
+ %47 = OpBitcast %_ptr_Generic_uchar %28
+ %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0
+ %21 = OpBitcast %_ptr_Generic_ulong %48
%29 = OpLoad %ulong %21
%15 = OpCopyObject %ulong %29
OpStore %7 %15
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx
new file mode 100644
index 0000000..1fc37d1
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx
@@ -0,0 +1,31 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .b64 in_addr;
+ .reg .b64 out_addr;
+ .reg .u32 tid_32;
+ .reg .u64 tid_64;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr, in_addr;
+ cvta.to.global.u64 out_addr, out_addr;
+
+ mov.u32 tid_32, %tid.x;
+ cvt.u64.u32 tid_64, tid_32;
+
+ add.u64 in_addr, in_addr, tid_64;
+ add.u64 out_addr, out_addr, tid_64;
+
+ ld.global.u64 temp, [in_addr];
+ st.global.u64 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
new file mode 100644
index 0000000..c53ad51
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
@@ -0,0 +1,89 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %49 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID
+ OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v4uint = OpTypeVector %uint 4
+%_ptr_Input_v4uint = OpTypePointer Input %v4uint
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %56 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %1 = OpFunction %void None %56
+ %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %47 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %20
+ OpStore %3 %21
+ %13 = OpBitcast %_ptr_Function_ulong %2
+ %43 = OpLoad %ulong %13
+ %12 = OpCopyObject %ulong %43
+ %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12
+ OpStore %10 %22
+ %15 = OpBitcast %_ptr_Function_ulong %3
+ %44 = OpLoad %ulong %15
+ %14 = OpCopyObject %ulong %44
+ %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14
+ OpStore %11 %23
+ %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %17 = OpConvertPtrToU %ulong %24
+ %16 = OpCopyObject %ulong %17
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16
+ OpStore %10 %25
+ %26 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %19 = OpConvertPtrToU %ulong %26
+ %18 = OpCopyObject %ulong %19
+ %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18
+ OpStore %11 %27
+ %29 = OpLoad %v4uint %gl_LocalInvocationID
+ %42 = OpCompositeExtract %uint %29 0
+ %28 = OpCopyObject %uint %42
+ OpStore %6 %28
+ %31 = OpLoad %uint %6
+ %61 = OpBitcast %uint %31
+ %30 = OpUConvert %ulong %61
+ OpStore %7 %30
+ %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %34 = OpLoad %ulong %7
+ %62 = OpBitcast %_ptr_CrossWorkgroup_uchar %33
+ %63 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %62 %34
+ %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %63
+ OpStore %10 %32
+ %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %37 = OpLoad %ulong %7
+ %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %36
+ %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %37
+ %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %65
+ OpStore %11 %35
+ %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %45 = OpBitcast %_ptr_CrossWorkgroup_ulong %39
+ %38 = OpLoad %ulong %45
+ OpStore %8 %38
+ %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %41 = OpLoad %ulong %8
+ %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %40
+ OpStore %46 %41
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx
new file mode 100644
index 0000000..ef7645d
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx
@@ -0,0 +1,35 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid_chain(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .b64 in_addr1;
+ .reg .b64 in_addr2;
+ .reg .b64 in_addr3;
+ .reg .b64 out_addr1;
+ .reg .b64 out_addr2;
+ .reg .b64 out_addr3;
+ .reg .u32 tid_32;
+ .reg .u64 tid_64;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr1, [input];
+ ld.param.u64 out_addr1, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr1;
+ cvta.to.global.u64 out_addr2, out_addr1;
+
+ mov.u32 tid_32, %tid.x;
+ cvt.u64.u32 tid_64, tid_32;
+
+ add.u64 in_addr3, in_addr2, tid_64;
+ add.u64 out_addr3, out_addr2, tid_64;
+
+ ld.global.u64 temp, [in_addr3];
+ st.global.u64 [out_addr3], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
new file mode 100644
index 0000000..5ba889c
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
@@ -0,0 +1,93 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %57 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID
+ OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v4uint = OpTypeVector %uint 4
+%_ptr_Input_v4uint = OpTypePointer Input %v4uint
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %1 = OpFunction %void None %64
+ %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %55 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function_uint Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ %12 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %28
+ OpStore %3 %29
+ %21 = OpBitcast %_ptr_Function_ulong %2
+ %51 = OpLoad %ulong %21
+ %20 = OpCopyObject %ulong %51
+ %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20
+ OpStore %14 %30
+ %23 = OpBitcast %_ptr_Function_ulong %3
+ %52 = OpLoad %ulong %23
+ %22 = OpCopyObject %ulong %52
+ %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
+ OpStore %17 %31
+ %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14
+ %25 = OpConvertPtrToU %ulong %32
+ %24 = OpCopyObject %ulong %25
+ %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24
+ OpStore %15 %33
+ %34 = OpLoad %_ptr_CrossWorkgroup_uchar %17
+ %27 = OpConvertPtrToU %ulong %34
+ %26 = OpCopyObject %ulong %27
+ %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
+ OpStore %18 %35
+ %37 = OpLoad %v4uint %gl_LocalInvocationID
+ %50 = OpCompositeExtract %uint %37 0
+ %36 = OpCopyObject %uint %50
+ OpStore %10 %36
+ %39 = OpLoad %uint %10
+ %69 = OpBitcast %uint %39
+ %38 = OpUConvert %ulong %69
+ OpStore %11 %38
+ %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15
+ %42 = OpLoad %ulong %11
+ %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %41
+ %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %42
+ %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %71
+ OpStore %16 %40
+ %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18
+ %45 = OpLoad %ulong %11
+ %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
+ %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %45
+ %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %73
+ OpStore %19 %43
+ %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16
+ %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %47
+ %46 = OpLoad %ulong %53
+ OpStore %12 %46
+ %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19
+ %49 = OpLoad %ulong %12
+ %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %48
+ OpStore %54 %49
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx
new file mode 100644
index 0000000..018918c
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx
@@ -0,0 +1,35 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid_sub(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .b64 in_addr1;
+ .reg .b64 in_addr2;
+ .reg .b64 in_addr3;
+ .reg .b64 out_addr1;
+ .reg .b64 out_addr2;
+ .reg .b64 out_addr3;
+ .reg .u32 tid_32;
+ .reg .u64 tid_64;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr1, [input];
+ ld.param.u64 out_addr1, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr1;
+ cvta.to.global.u64 out_addr2, out_addr1;
+
+ mov.u32 tid_32, %tid.x;
+ cvt.u64.u32 tid_64, tid_32;
+
+ sub.s64 in_addr3, in_addr2, tid_64;
+ sub.s64 out_addr3, out_addr2, tid_64;
+
+ ld.global.u64 temp, [in_addr3+-0];
+ st.global.u64 [out_addr3+-0], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt
new file mode 100644
index 0000000..3c215d4
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt
@@ -0,0 +1,105 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %65 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" %gl_LocalInvocationID
+ OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v4uint = OpTypeVector %uint 4
+%_ptr_Input_v4uint = OpTypePointer Input %v4uint
+%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %ulong_0 = OpConstant %ulong 0
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %ulong_0_0 = OpConstant %ulong 0
+ %1 = OpFunction %void None %72
+ %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %31 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %63 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function_uint Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ %12 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %30
+ OpStore %3 %31
+ %21 = OpBitcast %_ptr_Function_ulong %2
+ %57 = OpLoad %ulong %21
+ %20 = OpCopyObject %ulong %57
+ %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20
+ OpStore %14 %32
+ %23 = OpBitcast %_ptr_Function_ulong %3
+ %58 = OpLoad %ulong %23
+ %22 = OpCopyObject %ulong %58
+ %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
+ OpStore %17 %33
+ %34 = OpLoad %_ptr_CrossWorkgroup_uchar %14
+ %25 = OpConvertPtrToU %ulong %34
+ %24 = OpCopyObject %ulong %25
+ %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24
+ OpStore %15 %35
+ %36 = OpLoad %_ptr_CrossWorkgroup_uchar %17
+ %27 = OpConvertPtrToU %ulong %36
+ %26 = OpCopyObject %ulong %27
+ %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
+ OpStore %18 %37
+ %39 = OpLoad %v4uint %gl_LocalInvocationID
+ %52 = OpCompositeExtract %uint %39 0
+ %38 = OpCopyObject %uint %52
+ OpStore %10 %38
+ %41 = OpLoad %uint %10
+ %77 = OpBitcast %uint %41
+ %40 = OpUConvert %ulong %77
+ OpStore %11 %40
+ %42 = OpLoad %ulong %11
+ %59 = OpCopyObject %ulong %42
+ %28 = OpSNegate %ulong %59
+ %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15
+ %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
+ %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %28
+ %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %79
+ OpStore %16 %43
+ %45 = OpLoad %ulong %11
+ %60 = OpCopyObject %ulong %45
+ %29 = OpSNegate %ulong %60
+ %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18
+ %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %47
+ %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %29
+ %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %81
+ OpStore %19 %46
+ %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16
+ %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %49
+ %83 = OpBitcast %_ptr_CrossWorkgroup_uchar %61
+ %84 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %83 %ulong_0
+ %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %84
+ %48 = OpLoad %ulong %54
+ OpStore %12 %48
+ %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19
+ %51 = OpLoad %ulong %12
+ %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %50
+ %85 = OpBitcast %_ptr_CrossWorkgroup_uchar %62
+ %86 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %85 %ulong_0_0
+ %56 = OpBitcast %_ptr_CrossWorkgroup_ulong %86
+ OpStore %56 %51
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx
new file mode 100644
index 0000000..5650ada
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx
@@ -0,0 +1,25 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_simple(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 in_addr2;
+ .reg .u64 out_addr2;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr;
+ cvta.to.global.u64 out_addr2, out_addr;
+
+ ld.global.u64 temp, [in_addr2];
+ st.global.u64 [out_addr2], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
new file mode 100644
index 0000000..cfd87eb
--- /dev/null
+++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
@@ -0,0 +1,65 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %41 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "stateful_ld_st_simple"
+ %void = OpTypeVoid
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %45 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %1 = OpFunction %void None %45
+ %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %39 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %21
+ OpStore %3 %22
+ %14 = OpBitcast %_ptr_Function_ulong %2
+ %13 = OpLoad %ulong %14
+ %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13
+ OpStore %9 %23
+ %16 = OpBitcast %_ptr_Function_ulong %3
+ %15 = OpLoad %ulong %16
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
+ OpStore %10 %24
+ %25 = OpLoad %_ptr_CrossWorkgroup_uchar %9
+ %18 = OpConvertPtrToU %ulong %25
+ %34 = OpCopyObject %ulong %18
+ %33 = OpCopyObject %ulong %34
+ %17 = OpCopyObject %ulong %33
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17
+ OpStore %11 %26
+ %27 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %20 = OpConvertPtrToU %ulong %27
+ %36 = OpCopyObject %ulong %20
+ %35 = OpCopyObject %ulong %36
+ %19 = OpCopyObject %ulong %35
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19
+ OpStore %12 %28
+ %30 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %37 = OpBitcast %_ptr_CrossWorkgroup_ulong %30
+ %29 = OpLoad %ulong %37
+ OpStore %8 %29
+ %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12
+ %32 = OpLoad %ulong %8
+ %38 = OpBitcast %_ptr_CrossWorkgroup_ulong %31
+ OpStore %38 %32
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index f0a3187..328bf30 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,7 @@
use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -72,15 +72,13 @@ impl From<ast::PointerType> for ast::Type {
}
impl ast::Type {
- fn pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
+ fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
Ok(match self {
ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Vector(t, len) => {
ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
}
- ast::Type::Array(t, dims) => {
- ast::Type::Pointer(ast::PointerType::Array(t, dims), space)
- }
+ ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
}
@@ -726,7 +724,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
multi_hash_map_append(&mut directly_called_by, call.func, call_key);
Statement::Call(call)
}
- statement => statement.map_id(&mut |id| {
+ statement => statement.map_id(&mut |id, _| {
if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key);
}
@@ -843,7 +841,7 @@ fn replace_uses_of_shared_memory<'a>(
result.push(Statement::Call(call))
}
statement => {
- let new_statement = statement.map_id(&mut |id| {
+ let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
let replacement_id = new_id();
if *typ != ast::SizedScalarType::B8 {
@@ -859,6 +857,8 @@ fn replace_uses_of_shared_memory<'a>(
ast::LdStateSpace::Shared,
),
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
}
replacement_id
@@ -971,7 +971,7 @@ fn compute_denorm_information<'input>(
Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
- Statement::PtrAdd { .. } => {}
+ Statement::PtrAccess { .. } => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1022,15 +1022,10 @@ fn emit_builtins(
builder,
SpirvType::Pointer(
Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::UniformConstant,
+ spirv::StorageClass::Input,
),
);
- builder.variable(
- result_type,
- Some(*id),
- spirv::StorageClass::UniformConstant,
- None,
- );
+ builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None);
builder.decorate(
*id,
spirv::Decoration::BuiltIn,
@@ -1192,11 +1187,31 @@ fn translate_variable<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
var: ast::Variable<ast::VariableType, &'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (state_space, typ) = var.v_type.to_type();
+ let (space, var_type) = var.v_type.to_type();
+ let mut is_variable = false;
+ let var_type = match space {
+ ast::StateSpace::Reg => {
+ is_variable = true;
+ var_type
+ }
+ ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Shared => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ };
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)),
+ name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
}
@@ -1218,10 +1233,8 @@ fn expand_kernel_params<'a, 'b>(
Ok(ast::KernelArgument {
name: fn_resolver.add_def(
a.name,
- Some((
- StateSpace::Param,
- ast::Type::from(a.v_type.clone()).pointer_to(ast::LdStateSpace::Param)?,
- )),
+ Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
+ false,
),
v_type: a.v_type.clone(),
align: a.align,
@@ -1236,14 +1249,13 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let var_type = a.v_type.to_func_type();
- let ss = match a.v_type {
- ast::FnArgumentType::Reg(_) => StateSpace::Reg,
- ast::FnArgumentType::Param(_) => StateSpace::Param,
- ast::FnArgumentType::Shared => StateSpace::Shared,
+ let is_variable = match a.v_type {
+ ast::FnArgumentType::Reg(_) => true,
+ _ => false,
};
+ let var_type = a.v_type.to_func_type();
Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some((ss, var_type))),
+ name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
v_type: a.v_type.clone(),
align: a.align,
array_init: Vec::new(),
@@ -1274,12 +1286,18 @@ fn to_ssa<'input, 'b>(
};
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
- let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
+ let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ let typed_statements =
+ convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let ssa_statements = insert_mem_ssa_statements(
+ typed_statements,
+ &mut numeric_id_defs,
+ &f_args,
+ &mut spirv_decl,
+ )?;
let mut numeric_id_defs = numeric_id_defs.finish();
- let ssa_statements =
- insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -1421,62 +1439,23 @@ fn convert_to_typed_statements(
};
result.push(Statement::Call(resolved_call));
}
- // Supported ld/st:
- // global: only compatible with reg b64/u64/s64 source/dest
- // generic: compatible with global/local sources
- // param: compiled as mov
- // local compiled as mov
- // We would like to convert ld/st local/param to movs here,
- // but they have different semantics for implicit conversions
- // For now, we convert generic ld from local params to ld.local.
- // This way, we can rely on further stages of the compilation on
- // ld.generic & ld.global having bytes address source
- // One complication: immediate address is only allowed in local,
- // It is not supported in generic ld
- // ld.local foo, [1];
- ast::Instruction::Ld(mut d, arg) => {
- match arg.src.underlying() {
- None => {}
- Some(u) => {
- let (ss, _) = id_defs.get_typed(*u)?;
- match (d.state_space, ss) {
- (ast::LdStateSpace::Generic, StateSpace::Local) => {
- d.state_space = ast::LdStateSpace::Local;
- }
- _ => {}
- };
- }
- };
+ ast::Instruction::Ld(d, arg) => {
result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
}
- ast::Instruction::St(mut d, arg) => {
- match arg.src1.underlying() {
- None => {}
- Some(u) => {
- let (ss, _) = id_defs.get_typed(*u)?;
- match (d.state_space, ss) {
- (ast::StStateSpace::Generic, StateSpace::Local) => {
- d.state_space = ast::StStateSpace::Local;
- }
- _ => (),
- };
- }
- };
+ ast::Instruction::St(d, arg) => {
result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
}
ast::Instruction::Mov(mut d, args) => match args {
ast::Arg2Mov::Normal(arg) => {
if let Some(src_id) = arg.src.single_underlying() {
- let (scope, _) = id_defs.get_typed(*src_id)?;
- d.src_is_address = match scope {
- StateSpace::Reg => false,
- StateSpace::Const
- | StateSpace::Global
- | StateSpace::Local
- | StateSpace::Shared
- | StateSpace::Param
- | StateSpace::ParamReg => true,
+ let (typ, _) = id_defs.get_typed(*src_id)?;
+ let take_address = match typ {
+ ast::Type::Scalar(_) => false,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => true,
+ ast::Type::Pointer(_, _) => true,
};
+ d.src_is_address = take_address;
}
result.push(Statement::Instruction(ast::Instruction::Mov(
d,
@@ -1486,7 +1465,7 @@ fn convert_to_typed_statements(
ast::Arg2Mov::Member(args) => {
if let Some(dst_typ) = args.vector_dst() {
match id_defs.get_typed(*dst_typ)? {
- (_, ast::Type::Vector(_, len)) => {
+ (ast::Type::Vector(_, len), _) => {
d.dst_width = len;
}
_ => return Err(TranslateError::MismatchedType),
@@ -1494,7 +1473,7 @@ fn convert_to_typed_statements(
};
if let Some((src_typ, _)) = args.vector_src() {
match id_defs.get_typed(*src_typ)? {
- (_, ast::Type::Vector(_, len)) => {
+ (ast::Type::Vector(_, len), _) => {
d.src_width = len;
}
_ => return Err(TranslateError::MismatchedType),
@@ -1650,17 +1629,8 @@ fn convert_to_typed_statements(
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
- Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
- Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
- Statement::Call(c) => result.push(Statement::Call(c.cast())),
- Statement::Composite(c) => result.push(Statement::Composite(c)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- Statement::Conversion(c) => result.push(Statement::Conversion(c)),
- Statement::Constant(c) => result.push(Statement::Constant(c)),
- Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Undef(_, _) | Statement::PtrAdd { .. } => {
- return Err(TranslateError::Unreachable)
- }
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
@@ -1689,14 +1659,14 @@ fn to_ptx_impl_atomic_call(
};
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_id(None);
+ let fn_id = id_defs.new_non_variable(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1707,7 +1677,7 @@ fn to_ptx_impl_atomic_call(
ast::SizedScalarType::U32,
ptr_space,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1715,7 +1685,7 @@ fn to_ptx_impl_atomic_call(
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
],
@@ -1779,12 +1749,12 @@ fn to_ptx_impl_bfe_call(
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_id(None);
+ let fn_id = id_defs.new_non_variable(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1792,7 +1762,7 @@ fn to_ptx_impl_bfe_call(
ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1800,7 +1770,7 @@ fn to_ptx_impl_bfe_call(
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1808,7 +1778,7 @@ fn to_ptx_impl_bfe_call(
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
],
@@ -1893,10 +1863,10 @@ fn normalize_labels(
| Statement::Constant(_)
| Statement::Label(_)
| Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
+ | Statement::PtrAccess { .. } => {}
}
}
- iter::once(Statement::Label(id_def.new_id(None)))
+ iter::once(Statement::Label(id_def.new_non_variable(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -1907,15 +1877,15 @@ fn normalize_labels(
fn normalize_predicates(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<UnconditionalStatement> {
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_id(None);
- let if_false = id_def.new_id(None);
+ let if_true = id_def.new_non_variable(None);
+ let if_false = id_def.new_non_variable(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -1940,20 +1910,25 @@ fn normalize_predicates(
}
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- _ => unreachable!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
- result
+ Ok(result)
}
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
- id_def: &mut MutableNumericIdResolver,
+ id_def: &mut NumericIdResolver,
+ ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
+ let is_func = match ast_fn_decl {
+ ast::MethodDecl::Func(..) => true,
+ ast::MethodDecl::Kernel { .. } => false,
+ };
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type)? {
+ match type_to_variable_type(&arg.v_type, is_func)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
@@ -1965,25 +1940,25 @@ fn insert_mem_ssa_statements<'a, 'b>(
None => return Err(TranslateError::Unreachable),
}
}
- for arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&arg.v_type)? {
+ for spirv_arg in fn_decl.input.iter_mut() {
+ match type_to_variable_type(&spirv_arg.v_type, is_func)? {
Some(var_type) => {
- let typ = arg.v_type.clone();
- let new_id = id_def.new_id(typ.clone());
+ let typ = spirv_arg.v_type.clone();
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable {
- align: arg.align,
+ align: spirv_arg.align,
v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
+ name: spirv_arg.name,
+ array_init: spirv_arg.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
- src1: arg.name,
+ src1: spirv_arg.name,
src2: new_id,
},
typ,
));
- arg.name = new_id;
+ spirv_arg.name = new_id;
}
None => {}
}
@@ -1997,8 +1972,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
if let &[out_param] = &fn_decl.output.as_slice() {
- let typ = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_id(typ.clone());
+ let (typ, _) = id_def.get_typed(out_param.name)?;
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
@@ -2014,7 +1989,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred));
+ let generated_id =
+ id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
@@ -2025,21 +2001,23 @@ fn insert_mem_ssa_statements<'a, 'b>(
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, conv)?
+ }
+ Statement::PtrAccess(ptr_access) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- Statement::LoadVar(_, _)
- | Statement::StoreVar(_, _)
- | Statement::Conversion(_)
- | Statement::RetValue(_, _)
- | Statement::Constant(_)
- | Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
- Statement::Composite(_) => todo!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
}
-fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, TranslateError> {
+fn type_to_variable_type(
+ t: &ast::Type,
+ is_func: bool,
+) -> Result<Option<ast::VariableType>, TranslateError> {
Ok(match t {
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
@@ -2054,7 +2032,22 @@ fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, Tra
.map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
- ast::Type::Pointer(_, _) => None,
+ ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
+ if is_func {
+ return Ok(None);
+ }
+ Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ scalar_type
+ .clone()
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ (*space)
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ )))
+ }
+ ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
+ _ => return Err(TranslateError::Unreachable),
})
}
@@ -2105,34 +2098,28 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
}
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
- id_def: &mut MutableNumericIdResolver,
+ id_def: &mut NumericIdResolver,
result: &mut Vec<TypedStatement>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
- let new_statement =
- stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, instr_type| {
- if instr_type.is_none() || desc.sema == ArgumentSemantics::RegisterPointer {
+ let new_statement = stmt.visit_variable(
+ &mut |desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ if expected_type.is_none() {
return Ok(desc.op);
- }
- let id_type = match (id_def.get_typed(desc.op)?, desc.sema) {
- (_, ArgumentSemantics::Address) => return Ok(desc.op),
- (t, ArgumentSemantics::RegisterPointer)
- | (t, ArgumentSemantics::Default)
- | (t, ArgumentSemantics::DefaultRelaxed)
- | (t, ArgumentSemantics::PhysicalPointer) => t,
};
- if let ast::Type::Array(_, _) = id_type {
+ let (var_type, is_variable) = id_def.get_typed(desc.op)?;
+ if !is_variable {
return Ok(desc.op);
}
- let generated_id = id_def.new_id(id_type.clone());
+ let generated_id = id_def.new_non_variable(Some(var_type.clone()));
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: desc.op,
},
- id_type,
+ var_type,
));
} else {
post_statements.push(Statement::StoreVar(
@@ -2140,11 +2127,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
src1: desc.op,
src2: generated_id,
},
- id_type,
+ var_type,
));
}
Ok(generated_id)
- })?;
+ },
+ )?;
result.push(new_statement);
result.append(&mut post_statements);
Ok(())
@@ -2180,65 +2168,21 @@ fn expand_arguments<'a, 'b>(
name,
array_init,
})),
- Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- constant_src,
- } => {
+ Statement::PtrAccess(ptr_access) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let sema = match state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space);
- let new_dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_ptr_src = visitor.id(
- ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_constant_src = visitor.id(
- ArgumentDescriptor {
- op: constant_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Scalar(ast::ScalarType::S64)),
- )?;
- result.push(Statement::PtrAdd {
- underlying_type,
- state_space,
- dst: new_dst,
- ptr_src: new_ptr_src,
- constant_src: new_constant_src,
- })
+ let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::PtrAccess(new_inst));
+ result.extend(post_stmts);
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Undef(_, _) => unreachable!(),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
+ return Err(TranslateError::Unreachable)
+ }
}
}
Ok(result)
@@ -2270,7 +2214,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
scalar_sema_override: Option<ArgumentSemantics>,
composite_src: (spirv::Word, u8),
) -> spirv::Word {
- let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
+ let new_id =
+ scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0)));
func.push(Statement::Composite(CompositeRead {
typ: typ.0,
dst: new_id,
@@ -2301,20 +2246,20 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
ast::Type::Pointer(underlying_type, state_space) => {
let reg_typ = self.id_def.get_typed(reg)?;
if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_id(typ.clone());
+ let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- let dst = self.id_def.new_id(typ.clone());
- self.func.push(Statement::PtrAdd {
+ let dst = self.id_def.new_non_variable(typ.clone());
+ self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: underlying_type.clone(),
state_space: *state_space,
dst,
ptr_src: reg,
- constant_src: id_constant_stmt,
- });
+ offset_src: id_constant_stmt,
+ }));
return Ok(dst);
} else {
add_type = self.id_def.get_typed(reg)?;
@@ -2346,8 +2291,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
} else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(add_type.clone());
- let result_id = self.id_def.new_id(add_type);
+ let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
+ let result_id = self.id_def.new_non_variable(add_type);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition {
@@ -2395,7 +2340,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
} else {
todo!()
};
- let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2430,10 +2375,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
) -> Result<spirv::Word, TranslateError> {
let (scalar_type, vec_len) = typ.get_vector()?;
if !desc.is_dst {
- let mut new_id = self.id_def.new_id(typ.clone());
+ let mut new_id = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Undef(typ.clone(), new_id));
for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_id(typ.clone());
+ let newer_id = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
typ: ast::Type::Scalar(scalar_type),
@@ -2452,7 +2397,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}
Ok(new_id)
} else {
- let new_id = self.id_def.new_id(typ.clone());
+ let new_id = self.id_def.new_non_variable(typ.clone());
for (idx, id) in desc.op.iter().enumerate() {
Self::insert_composite_read(
&mut self.post_stmts,
@@ -2597,13 +2542,13 @@ fn insert_implicit_conversions(
should_bitcast_wrapper,
None,
)?,
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
+ offset_src: constant_src,
+ }) => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
op: ptr_src,
@@ -2611,12 +2556,14 @@ fn insert_implicit_conversions(
sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- constant_src,
+ stmt_ctor: |new_ptr_src| {
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src: new_ptr_src,
+ offset_src: constant_src,
+ })
},
};
insert_implicit_conversions_impl(
@@ -2628,6 +2575,7 @@ fn insert_implicit_conversions(
)?;
}
s @ Statement::Conditional(_)
+ | s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
@@ -2635,7 +2583,6 @@ fn insert_implicit_conversions(
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
- Statement::Conversion(_) => unreachable!(),
}
}
Ok(result)
@@ -2688,7 +2635,7 @@ fn insert_implicit_conversions_impl(
};
let mut from = instr_type.clone();
let mut to = operand_type;
- let mut src = id_def.new_id(instr_type.clone());
+ let mut src = id_def.new_non_variable(instr_type.clone());
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
@@ -2701,6 +2648,8 @@ fn insert_implicit_conversions_impl(
from,
to,
kind: conv_kind,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
result
}
@@ -3242,21 +3191,33 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
- let s64_type = map.get_or_add_scalar(builder, ast::ScalarType::S64);
- let ptr_as_s64 = builder.bitcast(s64_type, None, *ptr_src)?;
- let added_ptr = builder.i_add(s64_type, None, ptr_as_s64, *constant_src)?;
+ offset_src,
+ }) => {
+ let u8_pointer = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ *state_space,
+ )),
+ );
let result_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
);
- builder.bitcast(result_type, Some(*dst), added_ptr)?;
+ let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
+ let temp = builder.in_bounds_ptr_access_chain(
+ u8_pointer,
+ None,
+ ptr_src_u8,
+ *offset_src,
+ &[],
+ )?;
+ builder.bitcast(result_type, Some(*dst), temp)?;
}
}
}
@@ -3745,6 +3706,8 @@ fn emit_cvt(
src_t.kind(),
)),
kind: ConversionKind::Default,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -4117,6 +4080,8 @@ fn emit_implicit_conversion(
from: wide_bit_type,
to: cv.to.clone(),
kind: ConversionKind::Default,
+ src_sema: cv.src_sema,
+ dst_sema: cv.dst_sema,
},
)?;
}
@@ -4156,7 +4121,7 @@ fn normalize_identifiers<'a, 'b>(
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
- id_defs.add_def(*id, None);
+ id_defs.add_def(*id, None, false);
}
_ => (),
}
@@ -4189,23 +4154,35 @@ fn expand_map_variables<'a, 'b>(
i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
- let ss = match var.var.v_type {
- ast::VariableType::Reg(_) => StateSpace::Reg,
- ast::VariableType::Global(_) => StateSpace::Global,
- ast::VariableType::Shared(_) => StateSpace::Shared,
- ast::VariableType::Param(_) => StateSpace::ParamReg,
- ast::VariableType::Local(_) => StateSpace::Local,
- };
let mut var_type = ast::Type::from(var.var.v_type.clone());
+ let mut is_variable = false;
var_type = match var.var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Shared(_) => var_type,
- ast::VariableType::Global(_) => var_type.pointer_to(ast::LdStateSpace::Global)?,
- ast::VariableType::Param(_) => var_type.pointer_to(ast::LdStateSpace::Param)?,
- ast::VariableType::Local(_) => var_type.pointer_to(ast::LdStateSpace::Local)?,
+ ast::VariableType::Reg(_) => {
+ is_variable = true;
+ var_type
+ }
+ ast::VariableType::Shared(_) => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::VariableType::Global(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Global)?
+ }
+ ast::VariableType::Param(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Param)?
+ }
+ ast::VariableType::Local(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Local)?
+ }
};
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, ss, var_type) {
+ for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4215,7 +4192,7 @@ fn expand_map_variables<'a, 'b>(
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some((ss, var_type)));
+ let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4229,6 +4206,384 @@ fn expand_map_variables<'a, 'b>(
Ok(())
}
+// TODO: detect more patterns (mov, call via reg, call via param)
+// TODO: don't convert to ptr if the register is not ultimately used for ld/st
+// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
+// argument expansion
+// TODO: propagate through calls?
+fn convert_to_stateful_memory_access<'a>(
+ func_args: &mut SpirvMethodDecl,
+ func_body: Vec<TypedStatement>,
+ id_defs: &mut NumericIdResolver<'a>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let func_args_64bit = func_args
+ .input
+ .iter()
+ .filter_map(|arg| match arg.v_type {
+ ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
+ _ => None,
+ })
+ .collect::<HashSet<_>>();
+ let mut stateful_markers = Vec::new();
+ let mut stateful_init_reg = MultiHashMap::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Cvta(
+ ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Global,
+ size: ast::CvtaSize::U64,
+ from: ast::CvtaStateSpace::Generic,
+ },
+ arg,
+ )) => {
+ if let Some(src) = arg.src.underlying() {
+ if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) {
+ stateful_markers.push((arg.dst, *src));
+ }
+ }
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ ..
+ },
+ arg,
+ )) => {
+ if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) {
+ if func_args_64bit.contains(src) {
+ multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut func_args_ptr = HashSet::new();
+ let mut regs_ptr_current = HashSet::new();
+ for (dst, src) in stateful_markers {
+ if let Some(func_args) = stateful_init_reg.get(&src) {
+ for a in func_args {
+ func_args_ptr.insert(*a);
+ regs_ptr_current.insert(src);
+ regs_ptr_current.insert(dst);
+ }
+ }
+ }
+ // BTreeSet here to have a stable order of iteration,
+ // unfortunately our tests rely on it
+ let mut regs_ptr_seen = BTreeSet::new();
+ while regs_ptr_current.len() > 0 {
+ let mut regs_ptr_new = HashSet::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) => {
+ if let Some(src1) = arg.src1.underlying() {
+ if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ } else if let Some(src2) = arg.src2.underlying() {
+ if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ for id in regs_ptr_current {
+ regs_ptr_seen.insert(id);
+ }
+ regs_ptr_current = regs_ptr_new;
+ }
+ drop(regs_ptr_current);
+ let mut remapped_ids = HashMap::new();
+ let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
+ for reg in regs_ptr_seen {
+ let new_id = id_defs.new_variable(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ));
+ result.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: new_id,
+ array_init: Vec::new(),
+ v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U8,
+ ast::PointerStateSpace::Global,
+ )),
+ }));
+ remapped_ids.insert(reg, new_id);
+ }
+ for statement in func_body {
+ match statement {
+ l @ Statement::Label(_) => result.push(l),
+ c @ Statement::Conditional(_) => result.push(c),
+ Statement::Variable(var) => {
+ if !remapped_ids.contains_key(&var.name) {
+ result.push(Statement::Variable(var));
+ }
+ }
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
+ };
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
+ state_space: ast::LdStateSpace::Global,
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
+ }
+ Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
+ };
+ let offset_neg =
+ id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ result.push(Statement::Instruction(ast::Instruction::Neg(
+ ast::NegDetails {
+ typ: ast::ScalarType::S64,
+ flush_to_zero: None,
+ },
+ ast::Arg2 {
+ src: offset,
+ dst: offset_neg,
+ },
+ )));
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
+ state_space: ast::LdStateSpace::Global,
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: ast::Operand::Reg(offset_neg),
+ }))
+ }
+ Statement::Instruction(inst) => {
+ let mut post_statements = Vec::new();
+ let new_statement = inst.visit_variable(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ Statement::Call(call) => {
+ let mut post_statements = Vec::new();
+ let new_statement = call.visit_variable(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ _ => return Err(TranslateError::Unreachable),
+ }
+ }
+ for arg in func_args.input.iter_mut() {
+ if func_args_ptr.contains(&arg.name) {
+ arg.v_type = ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ );
+ }
+ }
+ Ok(result)
+}
+
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<spirv::Word, spirv::Word>,
+ func_args_ptr: &HashSet<spirv::Word>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>,
+) -> Result<spirv::Word, TranslateError> {
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ // We skip conversion here to trigger PtrAcces in a later pass
+ let old_type = match expected_type {
+ Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ _ => id_defs.get_typed(arg_desc.op)?.0,
+ };
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type_clone));
+ if arg_desc.is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from: old_type,
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: arg_desc.sema,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ }
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ // We skip conversion here to trigger PtrAcces in a later pass
+ let old_type = match expected_type {
+ Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ _ => id_defs.get_typed(arg_desc.op)?.0,
+ };
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
+ ast::LdStateSpace::Param,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToPtr { spirv_ptr: false },
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
+ })
+}
+
+fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
+ if !remapped_ids.contains_key(&arg.dst) {
+ return false;
+ }
+ match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => true,
+ Some(src2) if remapped_ids.contains_key(src2) => true,
+ _ => false,
+ }
+}
+
+fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
+ match id_defs.get_typed(id) {
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
+ _ => false,
+ }
+}
+
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
@@ -4270,7 +4625,7 @@ impl PtxSpecialRegister {
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -4295,15 +4650,16 @@ impl<'a> GlobalStringIdResolver<'a> {
self.get_or_add_impl(id, None)
}
- fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word {
- self.get_or_add_impl(id, Some(typ))
- }
-
- fn get_or_add_impl(
+ fn get_or_add_def_typed(
&mut self,
id: &'a str,
- typ: Option<(StateSpace, ast::Type)>,
+ typ: ast::Type,
+ is_variable: bool,
) -> spirv::Word {
+ self.get_or_add_impl(id, Some((typ, is_variable)))
+ }
+
+ fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -4399,10 +4755,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ type_check: HashMap<u32, Option<(ast::Type, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -4452,13 +4808,14 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
}
}
- fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
+ fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check.insert(numeric_id, typ);
+ self.type_check
+ .insert(numeric_id, typ.map(|t| (t, is_variable)));
*self.current_id += 1;
numeric_id
}
@@ -4468,8 +4825,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
&mut self,
base_id: &'a str,
count: u32,
- ss: StateSpace,
typ: ast::Type,
+ is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
for i in 0..count {
@@ -4478,7 +4835,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
self.type_check
- .insert(numeric_id + i, Some((ss, typ.clone())));
+ .insert(numeric_id + i, Some((typ.clone(), is_variable)));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -4487,8 +4844,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
- type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
@@ -4497,23 +4854,32 @@ impl<'b> NumericIdResolver<'b> {
MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(&id) {
- Some(x) => Ok((StateSpace::Reg, x.get_type())),
+ Some(x) => Ok((x.get_type(), true)),
None => match self.global_type_check.get(&id) {
- Some(Some(x)) => Ok(x.clone()),
+ Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
},
},
}
}
- fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
+ // This is for identifiers which will be emitted later as OpVariable
+ // They are candidates for insertion of LoadVar/StoreVar
+ fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ);
+ self.type_check.insert(new_id, Some((typ, true)));
+ *self.current_id += 1;
+ new_id
+ }
+
+ fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ let new_id = *self.current_id;
+ self.type_check.insert(new_id, typ.map(|t| (t, false)));
*self.current_id += 1;
new_id
}
@@ -4529,11 +4895,11 @@ impl<'b> MutableNumericIdResolver<'b> {
}
fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(_, t)| t)
+ self.base.get_typed(id).map(|(t, _)| t)
}
- fn new_id(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_id(Some((StateSpace::Reg, typ)))
+ fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ self.base.new_non_variable(Some(typ))
}
}
@@ -4541,101 +4907,102 @@ enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(ast::Variable<ast::VariableType, P::Id>),
Instruction(I),
+ // SPIR-V compatible replacement for PTX predicates
+ Conditional(BrachCondition),
+ Call(ResolvedCall<P>),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
- Call(ResolvedCall<P>),
Composite(CompositeRead),
- // SPIR-V compatible replacement for PTX predicates
- Conditional(BrachCondition),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
- PtrAdd {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
- dst: spirv::Word,
- ptr_src: spirv::Word,
- constant_src: spirv::Word,
- },
+ PtrAccess(PtrAccess<P>),
}
impl ExpandedStatement {
- fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
+ fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement {
match self {
- Statement::Label(id) => Statement::Label(f(id)),
+ Statement::Label(id) => Statement::Label(f(id, false)),
Statement::Variable(mut var) => {
- var.name = f(var.name);
+ var.name = f(var.name, true);
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
+ .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| {
+ Ok(f(arg.op, arg.is_dst))
+ })
.unwrap(),
Statement::LoadVar(mut arg, typ) => {
- arg.dst = f(arg.dst);
- arg.src = f(arg.src);
+ arg.dst = f(arg.dst, true);
+ arg.src = f(arg.src, false);
Statement::LoadVar(arg, typ)
}
Statement::StoreVar(mut arg, typ) => {
- arg.src1 = f(arg.src1);
- arg.src2 = f(arg.src2);
+ arg.src1 = f(arg.src1, false);
+ arg.src2 = f(arg.src2, false);
Statement::StoreVar(arg, typ)
}
Statement::Call(mut call) => {
- for (id, _) in call.ret_params.iter_mut() {
- *id = f(*id);
+ for (id, typ) in call.ret_params.iter_mut() {
+ let is_dst = match typ {
+ ast::FnArgumentType::Reg(_) => true,
+ ast::FnArgumentType::Param(_) => false,
+ ast::FnArgumentType::Shared => false,
+ };
+ *id = f(*id, is_dst);
}
- call.func = f(call.func);
+ call.func = f(call.func, false);
for (id, _) in call.param_list.iter_mut() {
- *id = f(*id);
+ *id = f(*id, false);
}
Statement::Call(call)
}
Statement::Composite(mut composite) => {
- composite.dst = f(composite.dst);
- composite.src_composite = f(composite.src_composite);
+ composite.dst = f(composite.dst, true);
+ composite.src_composite = f(composite.src_composite, false);
Statement::Composite(composite)
}
Statement::Conditional(mut conditional) => {
- conditional.predicate = f(conditional.predicate);
- conditional.if_true = f(conditional.if_true);
- conditional.if_false = f(conditional.if_false);
+ conditional.predicate = f(conditional.predicate, false);
+ conditional.if_true = f(conditional.if_true, false);
+ conditional.if_false = f(conditional.if_false, false);
Statement::Conditional(conditional)
}
Statement::Conversion(mut conv) => {
- conv.dst = f(conv.dst);
- conv.src = f(conv.src);
+ conv.dst = f(conv.dst, true);
+ conv.src = f(conv.src, false);
Statement::Conversion(conv)
}
Statement::Constant(mut constant) => {
- constant.dst = f(constant.dst);
+ constant.dst = f(constant.dst, true);
Statement::Constant(constant)
}
Statement::RetValue(data, id) => {
- let id = f(id);
+ let id = f(id, false);
Statement::RetValue(data, id)
}
Statement::Undef(typ, id) => {
- let id = f(id);
+ let id = f(id, true);
Statement::Undef(typ, id)
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
- let dst = f(dst);
- let ptr_src = f(ptr_src);
- let constant_src = f(constant_src);
- Statement::PtrAdd {
+ offset_src: constant_src,
+ }) => {
+ let dst = f(dst, true);
+ let ptr_src = f(ptr_src, false);
+ let constant_src = f(constant_src, false);
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- }
+ offset_src: constant_src,
+ })
}
}
}
@@ -4740,6 +5107,70 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
}
}
+impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
+ fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<P, To>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<PtrAccess<To>, TranslateError> {
+ let sema = match self.state_space {
+ ast::LdStateSpace::Const
+ | ast::LdStateSpace::Global
+ | ast::LdStateSpace::Shared
+ | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
+ ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
+ ArgumentSemantics::RegisterPointer
+ }
+ };
+ let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let new_dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_ptr_src = visitor.id(
+ ArgumentDescriptor {
+ op: self.ptr_src,
+ is_dst: false,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_constant_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.offset_src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::S64),
+ )?;
+ Ok(PtrAccess {
+ underlying_type: self.underlying_type,
+ state_space: self.state_space,
+ dst: new_dst,
+ ptr_src: new_ptr_src,
+ offset_src: new_constant_src,
+ })
+ }
+}
+
+impl VisitVariable for PtrAccess<TypedArgParams> {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ Ok(Statement::PtrAccess(self.map(f)?))
+ }
+}
+
pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
@@ -5035,6 +5466,14 @@ pub struct ArgumentDescriptor<Op> {
sema: ArgumentSemantics,
}
+pub struct PtrAccess<P: ast::ArgParams> {
+ underlying_type: ast::PointerType,
+ state_space: ast::LdStateSpace,
+ dst: spirv::Word,
+ ptr_src: spirv::Word,
+ offset_src: P::Operand,
+}
+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access
@@ -5266,6 +5705,56 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ Ok(Statement::Instruction(self.map(f)?))
+ }
+}
+
+impl ImplicitConversion {
+ fn map<
+ T: ArgParamsEx<Id = spirv::Word>,
+ U: ArgParamsEx<Id = spirv::Word>,
+ V: ArgumentMapVisitor<T, U>,
+ >(
+ self,
+ visitor: &mut V,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ let new_dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: self.dst_sema,
+ },
+ Some(&self.to),
+ )?;
+ let new_src = visitor.id(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: self.src_sema,
+ },
+ Some(&self.from),
+ )?;
+ Ok(Statement::Conversion({
+ ImplicitConversion {
+ src: new_src,
+ dst: new_dst,
+ ..self
+ }
+ }))
+ }
+}
+
+impl VisitVariable for ImplicitConversion {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
@@ -5273,7 +5762,21 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
self,
f: &mut F,
) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
+ self.map(f)
+ }
+}
+
+impl VisitVariableExpanded for ImplicitConversion {
+ fn visit_variable_extended<
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<ExpandedStatement, TranslateError> {
+ self.map(f)
}
}
@@ -5708,6 +6211,8 @@ struct ImplicitConversion {
from: ast::Type,
to: ast::Type,
kind: ConversionKind,
+ src_sema: ArgumentSemantics,
+ dst_sema: ArgumentSemantics,
}
#[derive(PartialEq, Copy, Clone)]