diff options
author | Andrzej Janik <[email protected]> | 2020-11-19 22:12:12 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-11-19 22:12:12 +0100 |
commit | f77b653d363a3b05d34d390874cec631ff948814 (patch) | |
tree | 5e9639b77647209ee79855c9b574cf457e875520 | |
parent | eac5fbd806639c42813d06095fd3911a4664538b (diff) | |
download | ZLUDA-f77b653d363a3b05d34d390874cec631ff948814.tar.gz ZLUDA-f77b653d363a3b05d34d390874cec631ff948814.zip |
Implement stateless-to-stateful optimization
-rw-r--r-- | ptx/src/ast.rs | 6 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 12 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/atom_inc.spvtxt | 170 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/cvta.spvtxt | 83 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/extern_shared_call.spvtxt | 84 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 18 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/reg_local.spvtxt | 12 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx | 31 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt | 89 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx | 35 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt | 93 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx | 35 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt | 105 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_simple.ptx | 25 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt | 65 | ||||
-rw-r--r-- | ptx/src/translate.rs | 1119 |
16 files changed, 1495 insertions, 487 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5a5f6be..367f060 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -766,6 +766,8 @@ sub_type! { LdStType { Scalar(LdStScalarType), Vector(LdStScalarType, u8), + // Used in generated code + Pointer(PointerType, LdStateSpace), } } @@ -774,6 +776,10 @@ impl From<LdStType> for PointerType { match t { LdStType::Scalar(t) => PointerType::Scalar(t.into()), LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), + LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { + PointerType::Pointer(scalar_type, space) + } + LdStType::Pointer(..) => unreachable!(), } } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 6c231b2..d2c235a 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1237,18 +1237,18 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = { - "cvta" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { + "cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { ast::Instruction::Cvta(ast::CvtaDetails { - to: to, - from: ast::CvtaStateSpace::Generic, + to: ast::CvtaStateSpace::Generic, + from, size: s }, a) }, - "cvta" ".to" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { + "cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, - from: from, + to, + from: ast::CvtaStateSpace::Generic, size: s }, a) diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt index 6948cd9..fda26c5 100644 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -1,89 +1,81 @@ -; SPIR-V -; Version: 1.3 -; Generator: rspirv -; Bound: 60 -OpCapability GenericPointer -OpCapability Linkage -OpCapability Addresses -OpCapability Kernel -OpCapability Int8 -OpCapability Int16 
-OpCapability Int64 -OpCapability Float16 -OpCapability Float64 -; OpCapability FunctionFloatControlINTEL -; OpExtension "SPV_INTEL_float_controls2" -%49 = OpExtInstImport "OpenCL.std" -OpMemoryModel Physical64 OpenCL -OpEntryPoint Kernel %1 "atom_inc" -OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import -OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import -%50 = OpTypeVoid -%51 = OpTypeInt 32 0 -%52 = OpTypePointer Generic %51 -%53 = OpTypeFunction %51 %52 %51 -%54 = OpTypePointer CrossWorkgroup %51 -%55 = OpTypeFunction %51 %54 %51 -%56 = OpTypeInt 64 0 -%57 = OpTypeFunction %50 %56 %56 -%58 = OpTypePointer Function %56 -%59 = OpTypePointer Function %51 -%27 = OpConstant %51 101 -%28 = OpConstant %51 101 -%29 = OpConstant %56 4 -%31 = OpConstant %56 8 -%40 = OpFunction %51 None %53 -%42 = OpFunctionParameter %52 -%43 = OpFunctionParameter %51 -OpFunctionEnd -%44 = OpFunction %51 None %55 -%46 = OpFunctionParameter %54 -%47 = OpFunctionParameter %51 -OpFunctionEnd -%1 = OpFunction %50 None %57 -%9 = OpFunctionParameter %56 -%10 = OpFunctionParameter %56 -%39 = OpLabel -%2 = OpVariable %58 Function -%3 = OpVariable %58 Function -%4 = OpVariable %58 Function -%5 = OpVariable %58 Function -%6 = OpVariable %59 Function -%7 = OpVariable %59 Function -%8 = OpVariable %59 Function -OpStore %2 %9 -OpStore %3 %10 -%12 = OpLoad %56 %2 -%11 = OpCopyObject %56 %12 -OpStore %4 %11 -%14 = OpLoad %56 %3 -%13 = OpCopyObject %56 %14 -OpStore %5 %13 -%16 = OpLoad %56 %4 -%33 = OpConvertUToPtr %52 %16 -%15 = OpFunctionCall %51 %40 %33 %27 -OpStore %6 %15 -%18 = OpLoad %56 %4 -%34 = OpConvertUToPtr %54 %18 -%17 = OpFunctionCall %51 %44 %34 %28 -OpStore %7 %17 -%20 = OpLoad %56 %4 -%35 = OpConvertUToPtr %52 %20 -%19 = OpLoad %51 %35 -OpStore %8 %19 -%21 = OpLoad %56 %5 -%22 = OpLoad %51 %6 -%36 = OpConvertUToPtr %52 %21 -OpStore %36 %22 -%23 = OpLoad %56 %5 -%24 = OpLoad %51 %7 -%30 = OpIAdd %56 %23 %29 -%37 
= OpConvertUToPtr %52 %30 -OpStore %37 %24 -%25 = OpLoad %56 %5 -%26 = OpLoad %51 %8 -%32 = OpIAdd %56 %25 %31 -%38 = OpConvertUToPtr %52 %32 -OpStore %38 %26 -OpReturn -OpFunctionEnd
\ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %47 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "atom_inc" + OpDecorate %38 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + OpDecorate %42 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import + %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Generic_uint = OpTypePointer Generic %uint + %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint + %ulong = OpTypeInt 64 0 + %55 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint + %uint_101 = OpConstant %uint 101 + %uint_101_0 = OpConstant %uint 101 + %ulong_4 = OpConstant %ulong 4 + %ulong_8 = OpConstant %ulong 8 + %38 = OpFunction %uint None %51 + %40 = OpFunctionParameter %_ptr_Generic_uint + %41 = OpFunctionParameter %uint + OpFunctionEnd + %42 = OpFunction %uint None %53 + %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %45 = OpFunctionParameter %uint + OpFunctionEnd + %1 = OpFunction %void None %55 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %37 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 + OpStore %4 %11 + %12 = OpLoad %ulong %3 + OpStore %5 %12 + %14 = OpLoad %ulong %4 + %31 = 
OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpFunctionCall %uint %38 %31 %uint_101 + OpStore %6 %13 + %16 = OpLoad %ulong %4 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %15 = OpFunctionCall %uint %42 %32 %uint_101_0 + OpStore %7 %15 + %18 = OpLoad %ulong %4 + %33 = OpConvertUToPtr %_ptr_Generic_uint %18 + %17 = OpLoad %uint %33 + OpStore %8 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %34 = OpConvertUToPtr %_ptr_Generic_uint %19 + OpStore %34 %20 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %7 + %28 = OpIAdd %ulong %21 %ulong_4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %28 + OpStore %35 %22 + %23 = OpLoad %ulong %5 + %24 = OpLoad %uint %8 + %30 = OpIAdd %ulong %23 %ulong_8 + %36 = OpConvertUToPtr %_ptr_Generic_uint %30 + OpStore %36 %24 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index cf6ff8b..143d0a5 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -7,48 +7,59 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %27 = OpExtInstImport "OpenCL.std" + %37 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvta" %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %30 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %41 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %30 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %25 = OpLabel - %2 = OpVariable 
%_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function + %1 = OpFunction %void None %41 + %17 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %18 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %35 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 - OpStore %4 %9 - %10 = OpLoad %ulong %3 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %20 = OpCopyObject %ulong %12 - %19 = OpCopyObject %ulong %20 - %11 = OpCopyObject %ulong %19 - OpStore %4 %11 - %14 = OpLoad %ulong %5 - %22 = OpCopyObject %ulong %14 - %21 = OpCopyObject %ulong %22 - %13 = OpCopyObject %ulong %21 - OpStore %5 %13 - %16 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 - %15 = OpLoad %float %23 - OpStore %6 %15 - %17 = OpLoad %ulong %5 - %18 = OpLoad %float %6 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 - OpStore %24 %18 + OpStore %2 %17 + OpStore %3 %18 + %10 = OpBitcast %_ptr_Function_ulong %2 + %9 = OpLoad %ulong %10 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9 + OpStore %7 %19 + %12 = OpBitcast %_ptr_Function_ulong %3 + %11 = OpLoad %ulong %12 + %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11 + OpStore %8 %20 + %21 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %14 = OpConvertPtrToU %ulong %21 + %30 = OpCopyObject %ulong %14 + %29 = OpCopyObject %ulong %30 + %13 = OpCopyObject %ulong %29 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %7 %22 + %23 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %16 = OpConvertPtrToU %ulong %23 + %32 = OpCopyObject %ulong %16 + %31 = OpCopyObject %ulong %32 + %15 = 
OpCopyObject %ulong %31 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %8 %24 + %26 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %33 = OpBitcast %_ptr_CrossWorkgroup_float %26 + %25 = OpLoad %float %33 + OpStore %6 %25 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %28 = OpLoad %float %6 + %34 = OpBitcast %_ptr_CrossWorkgroup_float %27 + OpStore %34 %28 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index d979193..39f8683 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %48 = OpExtInstImport "OpenCL.std" + %46 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %14 "extern_shared_call" %1 OpDecorate %1 Alignment 4 @@ -18,78 +18,76 @@ %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %55 = OpTypeFunction %void %_ptr_Workgroup_uchar + %53 = OpTypeFunction %void %_ptr_Workgroup_uchar %_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %62 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar + %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %55 - %40 = OpFunctionParameter %_ptr_Workgroup_uchar - %56 = OpLabel - %41 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %2 = OpFunction %void None %53 + %38 = OpFunctionParameter %_ptr_Workgroup_uchar + %54 = OpLabel + %39 = OpVariable 
%_ptr_Function__ptr_Workgroup_uchar Function %3 = OpVariable %_ptr_Function_ulong Function - OpStore %41 %40 + OpStore %39 %38 OpBranch %13 %13 = OpLabel - %42 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41 - %5 = OpLoad %_ptr_Workgroup_uint %42 + %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 + %5 = OpLoad %_ptr_Workgroup_uint %40 %11 = OpBitcast %_ptr_Workgroup_ulong %5 %4 = OpLoad %ulong %11 OpStore %3 %4 %7 = OpLoad %ulong %3 %6 = OpIAdd %ulong %7 %ulong_2 OpStore %3 %6 - %43 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %41 - %8 = OpLoad %_ptr_Workgroup_uint %43 + %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 + %8 = OpLoad %_ptr_Workgroup_uint %41 %9 = OpLoad %ulong %3 %12 = OpBitcast %_ptr_Workgroup_ulong %8 OpStore %12 %9 OpReturn OpFunctionEnd - %14 = OpFunction %void None %62 + %14 = OpFunction %void None %60 %20 = OpFunctionParameter %ulong %21 = OpFunctionParameter %ulong - %44 = OpFunctionParameter %_ptr_Workgroup_uchar - %63 = OpLabel - %45 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %42 = OpFunctionParameter %_ptr_Workgroup_uchar + %61 = OpLabel + %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function %15 = OpVariable %_ptr_Function_ulong Function %16 = OpVariable %_ptr_Function_ulong Function %17 = OpVariable %_ptr_Function_ulong Function %18 = OpVariable %_ptr_Function_ulong Function %19 = OpVariable %_ptr_Function_ulong Function - OpStore %45 %44 - OpBranch %38 - %38 = OpLabel + OpStore %43 %42 + OpBranch %36 + %36 = OpLabel OpStore %15 %20 OpStore %16 %21 - %23 = OpLoad %ulong %15 - %22 = OpCopyObject %ulong %23 + %22 = OpLoad %ulong %15 OpStore %17 %22 - %25 = OpLoad %ulong %16 - %24 = OpCopyObject %ulong %25 - OpStore %18 %24 - %27 = OpLoad %ulong %17 - %34 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %27 - %26 = OpLoad %ulong %34 - OpStore %19 %26 - %46 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45 - %28 = OpLoad %_ptr_Workgroup_uint %46 - %29 = OpLoad %ulong %19 - %35 = OpBitcast 
%_ptr_Workgroup_ulong %28 - OpStore %35 %29 - %65 = OpFunctionCall %void %2 %44 - %47 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %45 - %31 = OpLoad %_ptr_Workgroup_uint %47 - %36 = OpBitcast %_ptr_Workgroup_ulong %31 - %30 = OpLoad %ulong %36 - OpStore %19 %30 - %32 = OpLoad %ulong %18 - %33 = OpLoad %ulong %19 - %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %32 - OpStore %37 %33 + %23 = OpLoad %ulong %16 + OpStore %18 %23 + %25 = OpLoad %ulong %17 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25 + %24 = OpLoad %ulong %32 + OpStore %19 %24 + %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 + %26 = OpLoad %_ptr_Workgroup_uint %44 + %27 = OpLoad %ulong %19 + %33 = OpBitcast %_ptr_Workgroup_ulong %26 + OpStore %33 %27 + %63 = OpFunctionCall %void %2 %42 + %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 + %29 = OpLoad %_ptr_Workgroup_uint %45 + %34 = OpBitcast %_ptr_Workgroup_ulong %29 + %28 = OpLoad %ulong %34 + OpStore %19 %28 + %30 = OpLoad %ulong %18 + %31 = OpLoad %ulong %19 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30 + OpStore %35 %31 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index bd74508..f18b15c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -133,6 +133,10 @@ test_ptx!( [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32],
[0b11000001u32]
);
+test_ptx!(stateful_ld_st_simple, [121u64], [121u64]);
+test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]);
+test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]);
+test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]);
struct DisplayError<T: Debug> {
err: T,
@@ -292,7 +296,7 @@ fn test_spvtxt_assert<'a>( rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
let spvtxt_mod = loader.module();
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
- if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) {
+ if !is_spirv_fns_equal(&spirv_module.spirv.functions, &spvtxt_mod.functions) {
// We could simply use ptx_mod.disassemble, but SPIRV-Tools text formatting is so much nicer
let spv_from_ptx_binary = spirv_module.spirv.assemble();
let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
@@ -364,6 +368,18 @@ impl<T: Copy + Eq + Hash> EqMap<T> { }
}
+/// Compares two SPIR-V function lists element by element.
+///
+/// Returns `true` only when both slices have the same length and every pair
+/// of functions at the same index is accepted by `is_spirv_fn_equal`.
+/// Evaluation stops at the first pair that differs.
+fn is_spirv_fns_equal(fns1: &[Function], fns2: &[Function]) -> bool {
+    // A length mismatch can never be equal; `zip` alone would silently
+    // ignore the surplus functions of the longer list.
+    fns1.len() == fns2.len()
+        && fns1
+            .iter()
+            .zip(fns2.iter())
+            .all(|(lhs, rhs)| is_spirv_fn_equal(lhs, rhs))
+}
+
+
fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
let mut map = EqMap::new();
if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) {
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 596cedc..5ce3689 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -22,7 +22,9 @@ %_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %ulong_1 = OpConstant %ulong 1 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_0 = OpConstant %ulong 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -48,12 +50,12 @@ %14 = OpLoad %ulong %7 %26 = OpCopyObject %ulong %14 %19 = OpIAdd %ulong %26 %ulong_1 - %27 = OpBitcast %_ptr_Function_ulong %4 + %27 = OpBitcast %_ptr_Generic_ulong %4 OpStore %27 %19 - %28 = OpBitcast %_ptr_Function_ulong %4 - %45 = OpBitcast %ulong %28 - %46 = OpIAdd %ulong %45 %ulong_0 - %21 = OpBitcast %_ptr_Function_ulong %46 + %28 = OpBitcast %_ptr_Generic_ulong %4 + %47 = OpBitcast %_ptr_Generic_uchar %28 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 + %21 = OpBitcast %_ptr_Generic_ulong %48 %29 = OpLoad %ulong %21 %15 = OpCopyObject %ulong %29 OpStore %7 %15 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx new file mode 100644 index 0000000..1fc37d1 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.ptx @@ -0,0 +1,31 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid( // test kernel: copies one u64 from (input + %tid.x) to (output + %tid.x)
+ .param .u64 input, // 64-bit parameter, used below as the source address
+ .param .u64 output // 64-bit parameter, used below as the destination address
+)
+{
+ .reg .b64 in_addr; // source address; rewritten in place by cvta and add
+ .reg .b64 out_addr; // destination address; rewritten in place by cvta and add
+ .reg .u32 tid_32; // raw %tid.x value
+ .reg .u64 tid_64; // %tid.x widened to 64 bits
+ .reg .u64 temp; // the value being copied
+
+ ld.param.u64 in_addr, [input]; // fetch both parameters into registers
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr, in_addr; // convert each address to the global space; source and destination are the same register
+ cvta.to.global.u64 out_addr, out_addr;
+
+ mov.u32 tid_32, %tid.x; // NOTE(review): reads %tid.x although the kernel name says "ntid"
+ cvt.u64.u32 tid_64, tid_32; // unsigned 32-bit to 64-bit conversion
+
+ add.u64 in_addr, in_addr, tid_64; // offset both addresses by tid_64; the value is added as-is, not scaled by element size
+ add.u64 out_addr, out_addr, tid_64;
+
+ ld.global.u64 temp, [in_addr]; // load from the offset source address ...
+ st.global.u64 [out_addr], temp; // ... and store to the offset destination address
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt new file mode 100644 index 0000000..c53ad51 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -0,0 +1,89 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %49 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %56 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %56 + %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %47 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %20 + OpStore %3 %21 + %13 = OpBitcast %_ptr_Function_ulong %2 + %43 = 
OpLoad %ulong %13 + %12 = OpCopyObject %ulong %43 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %10 %22 + %15 = OpBitcast %_ptr_Function_ulong %3 + %44 = OpLoad %ulong %15 + %14 = OpCopyObject %ulong %44 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + OpStore %11 %23 + %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %17 = OpConvertPtrToU %ulong %24 + %16 = OpCopyObject %ulong %17 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16 + OpStore %10 %25 + %26 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %19 = OpConvertPtrToU %ulong %26 + %18 = OpCopyObject %ulong %19 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 + OpStore %11 %27 + %29 = OpLoad %v4uint %gl_LocalInvocationID + %42 = OpCompositeExtract %uint %29 0 + %28 = OpCopyObject %uint %42 + OpStore %6 %28 + %31 = OpLoad %uint %6 + %61 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %61 + OpStore %7 %30 + %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %34 = OpLoad %ulong %7 + %62 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %63 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %62 %34 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %63 + OpStore %10 %32 + %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %37 = OpLoad %ulong %7 + %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %37 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %65 + OpStore %11 %35 + %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %45 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %45 + OpStore %8 %38 + %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %41 = OpLoad %ulong %8 + %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 + OpStore %46 %41 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx new file mode 100644 index 0000000..ef7645d --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.ptx @@ -0,0 +1,35 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid_chain( // test kernel: same copy as a load/cvta/add/ld/st sequence, but every step writes a fresh register
+ .param .u64 input, // 64-bit parameter, used below as the source address
+ .param .u64 output // 64-bit parameter, used below as the destination address
+)
+{
+ .reg .b64 in_addr1; // source address as loaded from the parameter
+ .reg .b64 in_addr2; // source address after cvta.to.global
+ .reg .b64 in_addr3; // source address after adding tid_64
+ .reg .b64 out_addr1; // destination address as loaded from the parameter
+ .reg .b64 out_addr2; // destination address after cvta.to.global
+ .reg .b64 out_addr3; // destination address after adding tid_64
+ .reg .u32 tid_32; // raw %tid.x value
+ .reg .u64 tid_64; // %tid.x widened to 64 bits
+ .reg .u64 temp; // the value being copied
+
+ ld.param.u64 in_addr1, [input]; // fetch both parameters into registers
+ ld.param.u64 out_addr1, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr1; // convert to the global space into a new register (addr1 -> addr2)
+ cvta.to.global.u64 out_addr2, out_addr1;
+
+ mov.u32 tid_32, %tid.x; // NOTE(review): reads %tid.x although the kernel name says "ntid"
+ cvt.u64.u32 tid_64, tid_32; // unsigned 32-bit to 64-bit conversion
+
+ add.u64 in_addr3, in_addr2, tid_64; // offset by tid_64 into a new register (addr2 -> addr3); no scaling by element size
+ add.u64 out_addr3, out_addr2, tid_64;
+
+ ld.global.u64 temp, [in_addr3]; // load through the last register of the source chain ...
+ st.global.u64 [out_addr3], temp; // ... and store through the last register of the destination chain
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt new file mode 100644 index 0000000..5ba889c --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -0,0 +1,93 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %57 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %64 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %64 + %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %55 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar 
Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function_uint Function + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %28 + OpStore %3 %29 + %21 = OpBitcast %_ptr_Function_ulong %2 + %51 = OpLoad %ulong %21 + %20 = OpCopyObject %ulong %51 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %14 %30 + %23 = OpBitcast %_ptr_Function_ulong %3 + %52 = OpLoad %ulong %23 + %22 = OpCopyObject %ulong %52 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + OpStore %17 %31 + %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 + %25 = OpConvertPtrToU %ulong %32 + %24 = OpCopyObject %ulong %25 + %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 + OpStore %15 %33 + %34 = OpLoad %_ptr_CrossWorkgroup_uchar %17 + %27 = OpConvertPtrToU %ulong %34 + %26 = OpCopyObject %ulong %27 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 + OpStore %18 %35 + %37 = OpLoad %v4uint %gl_LocalInvocationID + %50 = OpCompositeExtract %uint %37 0 + %36 = OpCopyObject %uint %50 + OpStore %10 %36 + %39 = OpLoad %uint %10 + %69 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %69 + OpStore %11 %38 + %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %42 = OpLoad %ulong %11 + %70 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %71 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %70 %42 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %71 + OpStore %16 %40 + %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 + %45 = OpLoad %ulong %11 + %72 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %73 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %72 %45 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %73 + OpStore %19 %43 + %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %53 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 + %46 = OpLoad %ulong %53 + OpStore %12 %46 + %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %49 = OpLoad %ulong %12 + %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 + OpStore %54 %49 
+ OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx new file mode 100644 index 0000000..018918c --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.ptx @@ -0,0 +1,35 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_ntid_sub( // test kernel: like the chained copy, but the offset is subtracted and the accesses use a "+-0" immediate
+ .param .u64 input, // 64-bit parameter, used below as the source address
+ .param .u64 output // 64-bit parameter, used below as the destination address
+)
+{
+ .reg .b64 in_addr1; // source address as loaded from the parameter
+ .reg .b64 in_addr2; // source address after cvta.to.global
+ .reg .b64 in_addr3; // source address after subtracting tid_64
+ .reg .b64 out_addr1; // destination address as loaded from the parameter
+ .reg .b64 out_addr2; // destination address after cvta.to.global
+ .reg .b64 out_addr3; // destination address after subtracting tid_64
+ .reg .u32 tid_32; // raw %tid.x value
+ .reg .u64 tid_64; // %tid.x widened to 64 bits
+ .reg .u64 temp; // the value being copied
+
+ ld.param.u64 in_addr1, [input]; // fetch both parameters into registers
+ ld.param.u64 out_addr1, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr1; // convert to the global space into a new register (addr1 -> addr2)
+ cvta.to.global.u64 out_addr2, out_addr1;
+
+ mov.u32 tid_32, %tid.x; // NOTE(review): reads %tid.x although the kernel name says "ntid"
+ cvt.u64.u32 tid_64, tid_32; // unsigned 32-bit to 64-bit conversion
+
+ sub.s64 in_addr3, in_addr2, tid_64; // signed subtract of tid_64 from each address (addr2 -> addr3)
+ sub.s64 out_addr3, out_addr2, tid_64;
+
+ ld.global.u64 temp, [in_addr3+-0]; // register plus a "-0" immediate offset; presumably exercises negative-offset operand handling - confirm
+ st.global.u64 [out_addr3+-0], temp; // same "+-0" addressing form on the store side
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt new file mode 100644 index 0000000..3c215d4 --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_sub.spvtxt @@ -0,0 +1,105 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %65 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_ntid_sub" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_Input_v4uint = OpTypePointer Input %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_Input_v4uint Input + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %72 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong + %ulong_0 = OpConstant %ulong 0 +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %ulong_0_0 = OpConstant %ulong 0 + %1 = OpFunction %void None %72 + %30 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %31 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %63 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %17 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + 
%18 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %19 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function_uint Function + %11 = OpVariable %_ptr_Function_ulong Function + %12 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %30 + OpStore %3 %31 + %21 = OpBitcast %_ptr_Function_ulong %2 + %57 = OpLoad %ulong %21 + %20 = OpCopyObject %ulong %57 + %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %14 %32 + %23 = OpBitcast %_ptr_Function_ulong %3 + %58 = OpLoad %ulong %23 + %22 = OpCopyObject %ulong %58 + %33 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + OpStore %17 %33 + %34 = OpLoad %_ptr_CrossWorkgroup_uchar %14 + %25 = OpConvertPtrToU %ulong %34 + %24 = OpCopyObject %ulong %25 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %24 + OpStore %15 %35 + %36 = OpLoad %_ptr_CrossWorkgroup_uchar %17 + %27 = OpConvertPtrToU %ulong %36 + %26 = OpCopyObject %ulong %27 + %37 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 + OpStore %18 %37 + %39 = OpLoad %v4uint %gl_LocalInvocationID + %52 = OpCompositeExtract %uint %39 0 + %38 = OpCopyObject %uint %52 + OpStore %10 %38 + %41 = OpLoad %uint %10 + %77 = OpBitcast %uint %41 + %40 = OpUConvert %ulong %77 + OpStore %11 %40 + %42 = OpLoad %ulong %11 + %59 = OpCopyObject %ulong %42 + %28 = OpSNegate %ulong %59 + %44 = OpLoad %_ptr_CrossWorkgroup_uchar %15 + %78 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %79 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %78 %28 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %79 + OpStore %16 %43 + %45 = OpLoad %ulong %11 + %60 = OpCopyObject %ulong %45 + %29 = OpSNegate %ulong %60 + %47 = OpLoad %_ptr_CrossWorkgroup_uchar %18 + %80 = OpBitcast %_ptr_CrossWorkgroup_uchar %47 + %81 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %80 %29 + %46 = OpBitcast %_ptr_CrossWorkgroup_uchar %81 + OpStore %19 %46 + %49 = OpLoad %_ptr_CrossWorkgroup_uchar %16 + %61 = OpBitcast %_ptr_CrossWorkgroup_ulong %49 + 
%83 = OpBitcast %_ptr_CrossWorkgroup_uchar %61 + %84 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %83 %ulong_0 + %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %84 + %48 = OpLoad %ulong %54 + OpStore %12 %48 + %50 = OpLoad %_ptr_CrossWorkgroup_uchar %19 + %51 = OpLoad %ulong %12 + %62 = OpBitcast %_ptr_CrossWorkgroup_ulong %50 + %85 = OpBitcast %_ptr_CrossWorkgroup_uchar %62 + %86 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %85 %ulong_0_0 + %56 = OpBitcast %_ptr_CrossWorkgroup_ulong %86 + OpStore %56 %51 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx new file mode 100644 index 0000000..5650ada --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx @@ -0,0 +1,25 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry stateful_ld_st_simple(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 in_addr2;
+ .reg .u64 out_addr2;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr;
+ cvta.to.global.u64 out_addr2, out_addr;
+
+ ld.global.u64 temp, [in_addr2];
+ st.global.u64 [out_addr2], temp;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt new file mode 100644 index 0000000..cfd87eb --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt @@ -0,0 +1,65 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %41 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "stateful_ld_st_simple" + %void = OpTypeVoid + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %45 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %45 + %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %39 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %21 + OpStore %3 %22 + %14 = OpBitcast %_ptr_Function_ulong %2 + %13 = OpLoad %ulong %14 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %9 %23 + %16 = OpBitcast %_ptr_Function_ulong %3 + %15 = OpLoad %ulong %16 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %10 %24 + %25 = OpLoad 
%_ptr_CrossWorkgroup_uchar %9 + %18 = OpConvertPtrToU %ulong %25 + %34 = OpCopyObject %ulong %18 + %33 = OpCopyObject %ulong %34 + %17 = OpCopyObject %ulong %33 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17 + OpStore %11 %26 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %20 = OpConvertPtrToU %ulong %27 + %36 = OpCopyObject %ulong %20 + %35 = OpCopyObject %ulong %36 + %19 = OpCopyObject %ulong %35 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 + OpStore %12 %28 + %30 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %37 = OpBitcast %_ptr_CrossWorkgroup_ulong %30 + %29 = OpLoad %ulong %37 + OpStore %8 %29 + %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12 + %32 = OpLoad %ulong %8 + %38 = OpBitcast %_ptr_CrossWorkgroup_ulong %31 + OpStore %38 %32 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index f0a3187..328bf30 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -72,15 +72,13 @@ impl From<ast::PointerType> for ast::Type { }
impl ast::Type {
- fn pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
+ fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
Ok(match self {
ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Vector(t, len) => {
ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
}
- ast::Type::Array(t, dims) => {
- ast::Type::Pointer(ast::PointerType::Array(t, dims), space)
- }
+ ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
}
@@ -726,7 +724,7 @@ fn convert_dynamic_shared_memory_usage<'input>( multi_hash_map_append(&mut directly_called_by, call.func, call_key);
Statement::Call(call)
}
- statement => statement.map_id(&mut |id| {
+ statement => statement.map_id(&mut |id, _| {
if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key);
}
@@ -843,7 +841,7 @@ fn replace_uses_of_shared_memory<'a>( result.push(Statement::Call(call))
}
statement => {
- let new_statement = statement.map_id(&mut |id| {
+ let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
let replacement_id = new_id();
if *typ != ast::SizedScalarType::B8 {
@@ -859,6 +857,8 @@ fn replace_uses_of_shared_memory<'a>( ast::LdStateSpace::Shared,
),
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
}
replacement_id
@@ -971,7 +971,7 @@ fn compute_denorm_information<'input>( Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
- Statement::PtrAdd { .. } => {}
+ Statement::PtrAccess { .. } => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1022,15 +1022,10 @@ fn emit_builtins( builder,
SpirvType::Pointer(
Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::UniformConstant,
+ spirv::StorageClass::Input,
),
);
- builder.variable(
- result_type,
- Some(*id),
- spirv::StorageClass::UniformConstant,
- None,
- );
+ builder.variable(result_type, Some(*id), spirv::StorageClass::Input, None);
builder.decorate(
*id,
spirv::Decoration::BuiltIn,
@@ -1192,11 +1187,31 @@ fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>,
var: ast::Variable<ast::VariableType, &'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (state_space, typ) = var.v_type.to_type();
+ let (space, var_type) = var.v_type.to_type();
+ let mut is_variable = false;
+ let var_type = match space {
+ ast::StateSpace::Reg => {
+ is_variable = true;
+ var_type
+ }
+ ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Shared => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ };
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)),
+ name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
}
@@ -1218,10 +1233,8 @@ fn expand_kernel_params<'a, 'b>( Ok(ast::KernelArgument {
name: fn_resolver.add_def(
a.name,
- Some((
- StateSpace::Param,
- ast::Type::from(a.v_type.clone()).pointer_to(ast::LdStateSpace::Param)?,
- )),
+ Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
+ false,
),
v_type: a.v_type.clone(),
align: a.align,
@@ -1236,14 +1249,13 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let var_type = a.v_type.to_func_type();
- let ss = match a.v_type {
- ast::FnArgumentType::Reg(_) => StateSpace::Reg,
- ast::FnArgumentType::Param(_) => StateSpace::Param,
- ast::FnArgumentType::Shared => StateSpace::Shared,
+ let is_variable = match a.v_type {
+ ast::FnArgumentType::Reg(_) => true,
+ _ => false,
};
+ let var_type = a.v_type.to_func_type();
Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some((ss, var_type))),
+ name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
v_type: a.v_type.clone(),
align: a.align,
array_init: Vec::new(),
@@ -1274,12 +1286,18 @@ fn to_ssa<'input, 'b>( };
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
- let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
+ let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ let typed_statements =
+ convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let ssa_statements = insert_mem_ssa_statements(
+ typed_statements,
+ &mut numeric_id_defs,
+ &f_args,
+ &mut spirv_decl,
+ )?;
let mut numeric_id_defs = numeric_id_defs.finish();
- let ssa_statements =
- insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -1421,62 +1439,23 @@ fn convert_to_typed_statements( };
result.push(Statement::Call(resolved_call));
}
- // Supported ld/st:
- // global: only compatible with reg b64/u64/s64 source/dest
- // generic: compatible with global/local sources
- // param: compiled as mov
- // local compiled as mov
- // We would like to convert ld/st local/param to movs here,
- // but they have different semantics for implicit conversions
- // For now, we convert generic ld from local params to ld.local.
- // This way, we can rely on further stages of the compilation on
- // ld.generic & ld.global having bytes address source
- // One complication: immediate address is only allowed in local,
- // It is not supported in generic ld
- // ld.local foo, [1];
- ast::Instruction::Ld(mut d, arg) => {
- match arg.src.underlying() {
- None => {}
- Some(u) => {
- let (ss, _) = id_defs.get_typed(*u)?;
- match (d.state_space, ss) {
- (ast::LdStateSpace::Generic, StateSpace::Local) => {
- d.state_space = ast::LdStateSpace::Local;
- }
- _ => {}
- };
- }
- };
+ ast::Instruction::Ld(d, arg) => {
result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
}
- ast::Instruction::St(mut d, arg) => {
- match arg.src1.underlying() {
- None => {}
- Some(u) => {
- let (ss, _) = id_defs.get_typed(*u)?;
- match (d.state_space, ss) {
- (ast::StStateSpace::Generic, StateSpace::Local) => {
- d.state_space = ast::StStateSpace::Local;
- }
- _ => (),
- };
- }
- };
+ ast::Instruction::St(d, arg) => {
result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
}
ast::Instruction::Mov(mut d, args) => match args {
ast::Arg2Mov::Normal(arg) => {
if let Some(src_id) = arg.src.single_underlying() {
- let (scope, _) = id_defs.get_typed(*src_id)?;
- d.src_is_address = match scope {
- StateSpace::Reg => false,
- StateSpace::Const
- | StateSpace::Global
- | StateSpace::Local
- | StateSpace::Shared
- | StateSpace::Param
- | StateSpace::ParamReg => true,
+ let (typ, _) = id_defs.get_typed(*src_id)?;
+ let take_address = match typ {
+ ast::Type::Scalar(_) => false,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => true,
+ ast::Type::Pointer(_, _) => true,
};
+ d.src_is_address = take_address;
}
result.push(Statement::Instruction(ast::Instruction::Mov(
d,
@@ -1486,7 +1465,7 @@ fn convert_to_typed_statements( ast::Arg2Mov::Member(args) => {
if let Some(dst_typ) = args.vector_dst() {
match id_defs.get_typed(*dst_typ)? {
- (_, ast::Type::Vector(_, len)) => {
+ (ast::Type::Vector(_, len), _) => {
d.dst_width = len;
}
_ => return Err(TranslateError::MismatchedType),
@@ -1494,7 +1473,7 @@ fn convert_to_typed_statements( };
if let Some((src_typ, _)) = args.vector_src() {
match id_defs.get_typed(*src_typ)? {
- (_, ast::Type::Vector(_, len)) => {
+ (ast::Type::Vector(_, len), _) => {
d.src_width = len;
}
_ => return Err(TranslateError::MismatchedType),
@@ -1650,17 +1629,8 @@ fn convert_to_typed_statements( },
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
- Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
- Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
- Statement::Call(c) => result.push(Statement::Call(c.cast())),
- Statement::Composite(c) => result.push(Statement::Composite(c)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- Statement::Conversion(c) => result.push(Statement::Conversion(c)),
- Statement::Constant(c) => result.push(Statement::Constant(c)),
- Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Undef(_, _) | Statement::PtrAdd { .. } => {
- return Err(TranslateError::Unreachable)
- }
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
@@ -1689,14 +1659,14 @@ fn to_ptx_impl_atomic_call( };
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_id(None);
+ let fn_id = id_defs.new_non_variable(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1707,7 +1677,7 @@ fn to_ptx_impl_atomic_call( ast::SizedScalarType::U32,
ptr_space,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1715,7 +1685,7 @@ fn to_ptx_impl_atomic_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
],
@@ -1779,12 +1749,12 @@ fn to_ptx_impl_bfe_call( let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_id(None);
+ let fn_id = id_defs.new_non_variable(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1792,7 +1762,7 @@ fn to_ptx_impl_bfe_call( ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1800,7 +1770,7 @@ fn to_ptx_impl_bfe_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
@@ -1808,7 +1778,7 @@ fn to_ptx_impl_bfe_call( v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
ast::ScalarType::U32,
)),
- name: id_defs.new_id(None),
+ name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
],
@@ -1893,10 +1863,10 @@ fn normalize_labels( | Statement::Constant(_)
| Statement::Label(_)
| Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
+ | Statement::PtrAccess { .. } => {}
}
}
- iter::once(Statement::Label(id_def.new_id(None)))
+ iter::once(Statement::Label(id_def.new_non_variable(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -1907,15 +1877,15 @@ fn normalize_labels( fn normalize_predicates(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<UnconditionalStatement> {
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_id(None);
- let if_false = id_def.new_id(None);
+ let if_true = id_def.new_non_variable(None);
+ let if_false = id_def.new_non_variable(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -1940,20 +1910,25 @@ fn normalize_predicates( }
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- _ => unreachable!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
- result
+ Ok(result)
}
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
- id_def: &mut MutableNumericIdResolver,
+ id_def: &mut NumericIdResolver,
+ ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
+ let is_func = match ast_fn_decl {
+ ast::MethodDecl::Func(..) => true,
+ ast::MethodDecl::Kernel { .. } => false,
+ };
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type)? {
+ match type_to_variable_type(&arg.v_type, is_func)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
@@ -1965,25 +1940,25 @@ fn insert_mem_ssa_statements<'a, 'b>( None => return Err(TranslateError::Unreachable),
}
}
- for arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&arg.v_type)? {
+ for spirv_arg in fn_decl.input.iter_mut() {
+ match type_to_variable_type(&spirv_arg.v_type, is_func)? {
Some(var_type) => {
- let typ = arg.v_type.clone();
- let new_id = id_def.new_id(typ.clone());
+ let typ = spirv_arg.v_type.clone();
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable {
- align: arg.align,
+ align: spirv_arg.align,
v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
+ name: spirv_arg.name,
+ array_init: spirv_arg.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
- src1: arg.name,
+ src1: spirv_arg.name,
src2: new_id,
},
typ,
));
- arg.name = new_id;
+ spirv_arg.name = new_id;
}
None => {}
}
@@ -1997,8 +1972,8 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
if let &[out_param] = &fn_decl.output.as_slice() {
- let typ = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_id(typ.clone());
+ let (typ, _) = id_def.get_typed(out_param.name)?;
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
@@ -2014,7 +1989,8 @@ fn insert_mem_ssa_statements<'a, 'b>( inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred));
+ let generated_id =
+ id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
@@ -2025,21 +2001,23 @@ fn insert_mem_ssa_statements<'a, 'b>( bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, conv)?
+ }
+ Statement::PtrAccess(ptr_access) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- Statement::LoadVar(_, _)
- | Statement::StoreVar(_, _)
- | Statement::Conversion(_)
- | Statement::RetValue(_, _)
- | Statement::Constant(_)
- | Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
- Statement::Composite(_) => todo!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
}
-fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, TranslateError> {
+fn type_to_variable_type(
+ t: &ast::Type,
+ is_func: bool,
+) -> Result<Option<ast::VariableType>, TranslateError> {
Ok(match t {
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
@@ -2054,7 +2032,22 @@ fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, Tra .map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
- ast::Type::Pointer(_, _) => None,
+ ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
+ if is_func {
+ return Ok(None);
+ }
+ Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ scalar_type
+ .clone()
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ (*space)
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ )))
+ }
+ ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
+ _ => return Err(TranslateError::Unreachable),
})
}
@@ -2105,34 +2098,28 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded }
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
- id_def: &mut MutableNumericIdResolver,
+ id_def: &mut NumericIdResolver,
result: &mut Vec<TypedStatement>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
- let new_statement =
- stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, instr_type| {
- if instr_type.is_none() || desc.sema == ArgumentSemantics::RegisterPointer {
+ let new_statement = stmt.visit_variable(
+ &mut |desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ if expected_type.is_none() {
return Ok(desc.op);
- }
- let id_type = match (id_def.get_typed(desc.op)?, desc.sema) {
- (_, ArgumentSemantics::Address) => return Ok(desc.op),
- (t, ArgumentSemantics::RegisterPointer)
- | (t, ArgumentSemantics::Default)
- | (t, ArgumentSemantics::DefaultRelaxed)
- | (t, ArgumentSemantics::PhysicalPointer) => t,
};
- if let ast::Type::Array(_, _) = id_type {
+ let (var_type, is_variable) = id_def.get_typed(desc.op)?;
+ if !is_variable {
return Ok(desc.op);
}
- let generated_id = id_def.new_id(id_type.clone());
+ let generated_id = id_def.new_non_variable(Some(var_type.clone()));
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
src: desc.op,
},
- id_type,
+ var_type,
));
} else {
post_statements.push(Statement::StoreVar(
@@ -2140,11 +2127,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( src1: desc.op,
src2: generated_id,
},
- id_type,
+ var_type,
));
}
Ok(generated_id)
- })?;
+ },
+ )?;
result.push(new_statement);
result.append(&mut post_statements);
Ok(())
@@ -2180,65 +2168,21 @@ fn expand_arguments<'a, 'b>( name,
array_init,
})),
- Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- constant_src,
- } => {
+ Statement::PtrAccess(ptr_access) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let sema = match state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space);
- let new_dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_ptr_src = visitor.id(
- ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_constant_src = visitor.id(
- ArgumentDescriptor {
- op: constant_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Scalar(ast::ScalarType::S64)),
- )?;
- result.push(Statement::PtrAdd {
- underlying_type,
- state_space,
- dst: new_dst,
- ptr_src: new_ptr_src,
- constant_src: new_constant_src,
- })
+ let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::PtrAccess(new_inst));
+ result.extend(post_stmts);
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Undef(_, _) => unreachable!(),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
+ return Err(TranslateError::Unreachable)
+ }
}
}
Ok(result)
@@ -2270,7 +2214,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { scalar_sema_override: Option<ArgumentSemantics>,
composite_src: (spirv::Word, u8),
) -> spirv::Word {
- let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
+ let new_id =
+ scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0)));
func.push(Statement::Composite(CompositeRead {
typ: typ.0,
dst: new_id,
@@ -2301,20 +2246,20 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ast::Type::Pointer(underlying_type, state_space) => {
let reg_typ = self.id_def.get_typed(reg)?;
if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_id(typ.clone());
+ let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- let dst = self.id_def.new_id(typ.clone());
- self.func.push(Statement::PtrAdd {
+ let dst = self.id_def.new_non_variable(typ.clone());
+ self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: underlying_type.clone(),
state_space: *state_space,
dst,
ptr_src: reg,
- constant_src: id_constant_stmt,
- });
+ offset_src: id_constant_stmt,
+ }));
return Ok(dst);
} else {
add_type = self.id_def.get_typed(reg)?;
@@ -2346,8 +2291,8 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(add_type.clone());
- let result_id = self.id_def.new_id(add_type);
+ let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
+ let result_id = self.id_def.new_non_variable(add_type);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition {
@@ -2395,7 +2340,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else {
todo!()
};
- let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2430,10 +2375,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ) -> Result<spirv::Word, TranslateError> {
let (scalar_type, vec_len) = typ.get_vector()?;
if !desc.is_dst {
- let mut new_id = self.id_def.new_id(typ.clone());
+ let mut new_id = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Undef(typ.clone(), new_id));
for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_id(typ.clone());
+ let newer_id = self.id_def.new_non_variable(typ.clone());
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
typ: ast::Type::Scalar(scalar_type),
@@ -2452,7 +2397,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }
Ok(new_id)
} else {
- let new_id = self.id_def.new_id(typ.clone());
+ let new_id = self.id_def.new_non_variable(typ.clone());
for (idx, id) in desc.op.iter().enumerate() {
Self::insert_composite_read(
&mut self.post_stmts,
@@ -2597,13 +2542,13 @@ fn insert_implicit_conversions( should_bitcast_wrapper,
None,
)?,
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
+ offset_src: constant_src,
+ }) => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
op: ptr_src,
@@ -2611,12 +2556,14 @@ fn insert_implicit_conversions( sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- constant_src,
+ stmt_ctor: |new_ptr_src| {
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src: new_ptr_src,
+ offset_src: constant_src,
+ })
},
};
insert_implicit_conversions_impl(
@@ -2628,6 +2575,7 @@ fn insert_implicit_conversions( )?;
}
s @ Statement::Conditional(_)
+ | s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
@@ -2635,7 +2583,6 @@ fn insert_implicit_conversions( | s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
- Statement::Conversion(_) => unreachable!(),
}
}
Ok(result)
@@ -2688,7 +2635,7 @@ fn insert_implicit_conversions_impl( };
let mut from = instr_type.clone();
let mut to = operand_type;
- let mut src = id_def.new_id(instr_type.clone());
+ let mut src = id_def.new_non_variable(instr_type.clone());
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
@@ -2701,6 +2648,8 @@ fn insert_implicit_conversions_impl( from,
to,
kind: conv_kind,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
result
}
@@ -3242,21 +3191,33 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
- let s64_type = map.get_or_add_scalar(builder, ast::ScalarType::S64);
- let ptr_as_s64 = builder.bitcast(s64_type, None, *ptr_src)?;
- let added_ptr = builder.i_add(s64_type, None, ptr_as_s64, *constant_src)?;
+ offset_src,
+ }) => {
+ let u8_pointer = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ *state_space,
+ )),
+ );
let result_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
);
- builder.bitcast(result_type, Some(*dst), added_ptr)?;
+ let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
+ let temp = builder.in_bounds_ptr_access_chain(
+ u8_pointer,
+ None,
+ ptr_src_u8,
+ *offset_src,
+ &[],
+ )?;
+ builder.bitcast(result_type, Some(*dst), temp)?;
}
}
}
@@ -3745,6 +3706,8 @@ fn emit_cvt( src_t.kind(),
)),
kind: ConversionKind::Default,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -4117,6 +4080,8 @@ fn emit_implicit_conversion( from: wide_bit_type,
to: cv.to.clone(),
kind: ConversionKind::Default,
+ src_sema: cv.src_sema,
+ dst_sema: cv.dst_sema,
},
)?;
}
@@ -4156,7 +4121,7 @@ fn normalize_identifiers<'a, 'b>( for s in func.iter() {
match s {
ast::Statement::Label(id) => {
- id_defs.add_def(*id, None);
+ id_defs.add_def(*id, None, false);
}
_ => (),
}
@@ -4189,23 +4154,35 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
- let ss = match var.var.v_type {
- ast::VariableType::Reg(_) => StateSpace::Reg,
- ast::VariableType::Global(_) => StateSpace::Global,
- ast::VariableType::Shared(_) => StateSpace::Shared,
- ast::VariableType::Param(_) => StateSpace::ParamReg,
- ast::VariableType::Local(_) => StateSpace::Local,
- };
let mut var_type = ast::Type::from(var.var.v_type.clone());
+ let mut is_variable = false;
var_type = match var.var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Shared(_) => var_type,
- ast::VariableType::Global(_) => var_type.pointer_to(ast::LdStateSpace::Global)?,
- ast::VariableType::Param(_) => var_type.pointer_to(ast::LdStateSpace::Param)?,
- ast::VariableType::Local(_) => var_type.pointer_to(ast::LdStateSpace::Local)?,
+ ast::VariableType::Reg(_) => {
+ is_variable = true;
+ var_type
+ }
+ ast::VariableType::Shared(_) => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::VariableType::Global(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Global)?
+ }
+ ast::VariableType::Param(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Param)?
+ }
+ ast::VariableType::Local(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Local)?
+ }
};
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, ss, var_type) {
+ for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4215,7 +4192,7 @@ fn expand_map_variables<'a, 'b>( }
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some((ss, var_type)));
+ let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4229,6 +4206,384 @@ fn expand_map_variables<'a, 'b>( Ok(())
}
+// Stateless-to-stateful pass: rewrites `func_body` so that 64-bit integer
+// registers recognised as holding global-memory addresses are replaced by
+// `.global` pointer-to-u8 variables, and retypes the kernel arguments those
+// addresses come from as `.global` u8 pointers.
+//
+// Recognition works in three steps:
+// 1. find registers loaded from a 64-bit kernel argument with
+// `ld.param.{u64,s64,b64}` that are also the source of a
+// `cvta.to.global.u64`,
+// 2. propagate "is a pointer" through 64-bit `add`/`sub` results until no new
+// register is found,
+// 3. re-emit the body, turning pointer `add`/`sub` into `PtrAccess` and
+// inserting conversions wherever a promoted register is used as an integer.
+//
+// Returns the rewritten statement list; `func_args.input` is updated in place.
+// Fails with `TranslateError::Unreachable` on a statement kind other than
+// Label, Conditional, Variable, Instruction or Call.
+// TODO: detect more patterns (mov, call via reg, call via param)
+// TODO: don't convert to ptr if the register is not ultimately used for ld/st
+// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
+// argument expansion
+// TODO: propagate through calls?
+fn convert_to_stateful_memory_access<'a>(
+ func_args: &mut SpirvMethodDecl,
+ func_body: Vec<TypedStatement>,
+ id_defs: &mut NumericIdResolver<'a>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ // Names of the input arguments whose declared type is a 64-bit integer
+ // (u64, b64 or s64); only these can be promoted to pointers.
+ let func_args_64bit = func_args
+ .input
+ .iter()
+ .filter_map(|arg| match arg.v_type {
+ ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
+ _ => None,
+ })
+ .collect::<HashSet<_>>();
+ // (dst, src) pairs of every generic-to-global `cvta` on 64-bit integers.
+ let mut stateful_markers = Vec::new();
+ // Register -> 64-bit arguments that were `ld.param`-loaded into it.
+ let mut stateful_init_reg = MultiHashMap::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Cvta(
+ ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Global,
+ size: ast::CvtaSize::U64,
+ from: ast::CvtaStateSpace::Generic,
+ },
+ arg,
+ )) => {
+ if let Some(src) = arg.src.underlying() {
+ if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) {
+ stateful_markers.push((arg.dst, *src));
+ }
+ }
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ ..
+ },
+ arg,
+ )) => {
+ // Only scalar-register destinations loaded from one of the
+ // 64-bit arguments are recorded.
+ if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) {
+ if func_args_64bit.contains(src) {
+ multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ // Arguments that will be retyped to `.global` u8 pointers.
+ let mut func_args_ptr = HashSet::new();
+ // Work list for the propagation below, seeded with both sides of every
+ // `cvta` whose source register was loaded from a 64-bit argument.
+ let mut regs_ptr_current = HashSet::new();
+ for (dst, src) in stateful_markers {
+ if let Some(func_args) = stateful_init_reg.get(&src) {
+ for a in func_args {
+ func_args_ptr.insert(*a);
+ regs_ptr_current.insert(src);
+ regs_ptr_current.insert(dst);
+ }
+ }
+ }
+ // BTreeSet here to have a stable order of iteration,
+ // unfortunately our tests rely on it
+ let mut regs_ptr_seen = BTreeSet::new();
+ // Fixed-point propagation: each round scans the whole body for 64-bit
+ // add/sub (unsigned, or signed without saturation) fed by a register from
+ // the current work list and queues its destination for the next round.
+ while regs_ptr_current.len() > 0 {
+ let mut regs_ptr_new = HashSet::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) => {
+ // NOTE(review): because of the `else if`, `src2` is looked at
+ // only when `src1` has no underlying id at all; a non-pointer
+ // register in `src1` hides a pointer in `src2`.
+ if let Some(src1) = arg.src1.underlying() {
+ if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ } else if let Some(src2) = arg.src2.underlying() {
+ if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ for id in regs_ptr_current {
+ regs_ptr_seen.insert(id);
+ }
+ regs_ptr_current = regs_ptr_new;
+ }
+ // The loop above exits only once this set is empty.
+ drop(regs_ptr_current);
+ // Promoted integer register -> id of the pointer variable replacing it.
+ let mut remapped_ids = HashMap::new();
+ let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
+ // Declare one `.global` pointer-to-u8 register variable per promoted
+ // register, at the very start of the new body.
+ for reg in regs_ptr_seen {
+ let new_id = id_defs.new_variable(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ));
+ result.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: new_id,
+ array_init: Vec::new(),
+ v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U8,
+ ast::PointerStateSpace::Global,
+ )),
+ }));
+ remapped_ids.insert(reg, new_id);
+ }
+ for statement in func_body {
+ match statement {
+ l @ Statement::Label(_) => result.push(l),
+ c @ Statement::Conditional(_) => result.push(c),
+ // Declarations of promoted registers are dropped; their
+ // replacements were declared above.
+ Statement::Variable(var) => {
+ if !remapped_ids.contains_key(&var.name) {
+ result.push(Statement::Variable(var));
+ }
+ }
+ // 64-bit add whose destination and `src1` are both promoted:
+ // emitted as a pointer offset instead of integer arithmetic.
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ // NOTE(review): both arms match on `arg.src1`; `src2` in the
+ // second arm is only a binding name, so that arm can never be
+ // taken when the first one was not.
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
+ };
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
+ state_space: ast::LdStateSpace::Global,
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
+ }
+ // 64-bit sub under the same guard: the offset is negated into a
+ // fresh s64 temporary, then applied as a pointer offset.
+ Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ // NOTE(review): same duplicated `src1` match as in the add case.
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
+ };
+ let offset_neg =
+ id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ result.push(Statement::Instruction(ast::Instruction::Neg(
+ ast::NegDetails {
+ typ: ast::ScalarType::S64,
+ flush_to_zero: None,
+ },
+ ast::Arg2 {
+ src: offset,
+ dst: offset_neg,
+ },
+ )));
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
+ state_space: ast::LdStateSpace::Global,
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: ast::Operand::Reg(offset_neg),
+ }))
+ }
+ // Every other instruction: each operand goes through the
+ // postprocess callback, which pushes source-side conversions
+ // straight into `result` (so they precede the instruction) and
+ // collects destination-side conversions in `post_statements`.
+ Statement::Instruction(inst) => {
+ let mut post_statements = Vec::new();
+ let new_statement = inst.visit_variable(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ // Calls get the same operand treatment as plain instructions.
+ Statement::Call(call) => {
+ let mut post_statements = Vec::new();
+ let new_statement = call.visit_variable(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ // No other statement kind is accepted by this pass.
+ _ => return Err(TranslateError::Unreachable),
+ }
+ }
+ // Retype the arguments that were recognised as pointers.
+ for arg in func_args.input.iter_mut() {
+ if func_args_ptr.contains(&arg.name) {
+ arg.v_type = ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ );
+ }
+ }
+ Ok(result)
+}
+
+// Operand callback used while re-emitting instructions and calls in
+// `convert_to_stateful_memory_access`. Returns the id the operand should use.
+//
+// * Operand is a promoted register (key of `remapped_ids`): when the
+// instruction expects a `.global` pointer the replacement pointer id is
+// returned as is; otherwise a temporary of the register's original type is
+// created and a conversion is emitted, into `post_statements` for a
+// destination (bits -> pointer, after the instruction) or into `result` for
+// a source (pointer -> bits, before the instruction).
+// * Operand is a retyped argument (member of `func_args_ptr`): it must not be
+// a destination; unless a `.global` pointer is expected, a pointer-to-pointer
+// conversion into a temporary of the original type is pushed to `result`.
+// * Anything else is returned unchanged.
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<spirv::Word, spirv::Word>,
+ func_args_ptr: &HashSet<spirv::Word>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>,
+) -> Result<spirv::Word, TranslateError> {
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ // We skip conversion here to trigger PtrAccess in a later pass
+ let old_type = match expected_type {
+ Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ _ => id_defs.get_typed(arg_desc.op)?.0,
+ };
+ let old_type_clone = old_type.clone();
+ // Temporary carrying the value in the register's original type.
+ let converting_id = id_defs.new_non_variable(Some(old_type_clone));
+ if arg_desc.is_dst {
+ // The instruction writes the temporary; afterwards it is
+ // converted into the replacement pointer.
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from: old_type,
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: arg_desc.sema,
+ }));
+ converting_id
+ } else {
+ // The replacement pointer is converted into the temporary
+ // before the instruction reads it.
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ }
+ // `new_id` below is the set element itself, i.e. equal to `arg_desc.op`.
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ // We skip conversion here to trigger PtrAccess in a later pass
+ let old_type = match expected_type {
+ Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ _ => id_defs.get_typed(arg_desc.op)?.0,
+ };
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type));
+ // Source type is a `.param` pointer to a `.global` u8 pointer.
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
+ ast::LdStateSpace::Param,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToPtr { spirv_ptr: false },
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
+ })
+}
+
+// Tells whether a three-operand arithmetic instruction can be emitted directly
+// as a pointer offset: its destination must be a promoted register (a key of
+// `remapped_ids`) and so must the id underlying its first source operand.
+//
+// Only `src1` is examined. `src2` is never consulted, so an instruction whose
+// pointer operand sits in the second source position is reported as `false`.
+fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
+    remapped_ids.contains_key(&arg.dst)
+        && match arg.src1.underlying() {
+            Some(base) => remapped_ids.contains_key(base),
+            None => false,
+        }
+}
+
+// Returns true when `id` resolves to a scalar of type u64, s64 or b64.
+// Ids that fail to resolve, or resolve to any other type, yield false.
+fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
+    let scalar = match id_defs.get_typed(id) {
+        Ok((ast::Type::Scalar(scalar), _)) => scalar,
+        _ => return false,
+    };
+    match scalar {
+        ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => true,
+        _ => false,
+    }
+}
+
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
@@ -4270,7 +4625,7 @@ impl PtxSpecialRegister { struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -4295,15 +4650,16 @@ impl<'a> GlobalStringIdResolver<'a> { self.get_or_add_impl(id, None)
}
- fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word {
- self.get_or_add_impl(id, Some(typ))
- }
-
- fn get_or_add_impl(
+ fn get_or_add_def_typed(
&mut self,
id: &'a str,
- typ: Option<(StateSpace, ast::Type)>,
+ typ: ast::Type,
+ is_variable: bool,
) -> spirv::Word {
+ self.get_or_add_impl(id, Some((typ, is_variable)))
+ }
+
+ fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -4399,10 +4755,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ type_check: HashMap<u32, Option<(ast::Type, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -4452,13 +4808,14 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { }
}
- fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
+ fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check.insert(numeric_id, typ);
+ self.type_check
+ .insert(numeric_id, typ.map(|t| (t, is_variable)));
*self.current_id += 1;
numeric_id
}
@@ -4468,8 +4825,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { &mut self,
base_id: &'a str,
count: u32,
- ss: StateSpace,
typ: ast::Type,
+ is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
for i in 0..count {
@@ -4478,7 +4835,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
self.type_check
- .insert(numeric_id + i, Some((ss, typ.clone())));
+ .insert(numeric_id + i, Some((typ.clone(), is_variable)));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -4487,8 +4844,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
- type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
@@ -4497,23 +4854,32 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(&id) {
- Some(x) => Ok((StateSpace::Reg, x.get_type())),
+ Some(x) => Ok((x.get_type(), true)),
None => match self.global_type_check.get(&id) {
- Some(Some(x)) => Ok(x.clone()),
+ Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
},
},
}
}
- fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
+ // This is for identifiers which will be emitted later as OpVariable
+ // They are candidates for insertion of LoadVar/StoreVar
+ fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ);
+ self.type_check.insert(new_id, Some((typ, true)));
+ *self.current_id += 1;
+ new_id
+ }
+
+ fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ let new_id = *self.current_id;
+ self.type_check.insert(new_id, typ.map(|t| (t, false)));
*self.current_id += 1;
new_id
}
@@ -4529,11 +4895,11 @@ impl<'b> MutableNumericIdResolver<'b> { }
fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(_, t)| t)
+ self.base.get_typed(id).map(|(t, _)| t)
}
- fn new_id(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_id(Some((StateSpace::Reg, typ)))
+ fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ self.base.new_non_variable(Some(typ))
}
}
@@ -4541,101 +4907,102 @@ enum Statement<I, P: ast::ArgParams> { Label(u32),
Variable(ast::Variable<ast::VariableType, P::Id>),
Instruction(I),
+ // SPIR-V compatible replacement for PTX predicates
+ Conditional(BrachCondition),
+ Call(ResolvedCall<P>),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
- Call(ResolvedCall<P>),
Composite(CompositeRead),
- // SPIR-V compatible replacement for PTX predicates
- Conditional(BrachCondition),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
- PtrAdd {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
- dst: spirv::Word,
- ptr_src: spirv::Word,
- constant_src: spirv::Word,
- },
+ PtrAccess(PtrAccess<P>),
}
impl ExpandedStatement {
- fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
+ fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement {
match self {
- Statement::Label(id) => Statement::Label(f(id)),
+ Statement::Label(id) => Statement::Label(f(id, false)),
Statement::Variable(mut var) => {
- var.name = f(var.name);
+ var.name = f(var.name, true);
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
+ .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| {
+ Ok(f(arg.op, arg.is_dst))
+ })
.unwrap(),
Statement::LoadVar(mut arg, typ) => {
- arg.dst = f(arg.dst);
- arg.src = f(arg.src);
+ arg.dst = f(arg.dst, true);
+ arg.src = f(arg.src, false);
Statement::LoadVar(arg, typ)
}
Statement::StoreVar(mut arg, typ) => {
- arg.src1 = f(arg.src1);
- arg.src2 = f(arg.src2);
+ arg.src1 = f(arg.src1, false);
+ arg.src2 = f(arg.src2, false);
Statement::StoreVar(arg, typ)
}
Statement::Call(mut call) => {
- for (id, _) in call.ret_params.iter_mut() {
- *id = f(*id);
+ for (id, typ) in call.ret_params.iter_mut() {
+ let is_dst = match typ {
+ ast::FnArgumentType::Reg(_) => true,
+ ast::FnArgumentType::Param(_) => false,
+ ast::FnArgumentType::Shared => false,
+ };
+ *id = f(*id, is_dst);
}
- call.func = f(call.func);
+ call.func = f(call.func, false);
for (id, _) in call.param_list.iter_mut() {
- *id = f(*id);
+ *id = f(*id, false);
}
Statement::Call(call)
}
Statement::Composite(mut composite) => {
- composite.dst = f(composite.dst);
- composite.src_composite = f(composite.src_composite);
+ composite.dst = f(composite.dst, true);
+ composite.src_composite = f(composite.src_composite, false);
Statement::Composite(composite)
}
Statement::Conditional(mut conditional) => {
- conditional.predicate = f(conditional.predicate);
- conditional.if_true = f(conditional.if_true);
- conditional.if_false = f(conditional.if_false);
+ conditional.predicate = f(conditional.predicate, false);
+ conditional.if_true = f(conditional.if_true, false);
+ conditional.if_false = f(conditional.if_false, false);
Statement::Conditional(conditional)
}
Statement::Conversion(mut conv) => {
- conv.dst = f(conv.dst);
- conv.src = f(conv.src);
+ conv.dst = f(conv.dst, true);
+ conv.src = f(conv.src, false);
Statement::Conversion(conv)
}
Statement::Constant(mut constant) => {
- constant.dst = f(constant.dst);
+ constant.dst = f(constant.dst, true);
Statement::Constant(constant)
}
Statement::RetValue(data, id) => {
- let id = f(id);
+ let id = f(id, false);
Statement::RetValue(data, id)
}
Statement::Undef(typ, id) => {
- let id = f(id);
+ let id = f(id, true);
Statement::Undef(typ, id)
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- } => {
- let dst = f(dst);
- let ptr_src = f(ptr_src);
- let constant_src = f(constant_src);
- Statement::PtrAdd {
+ offset_src: constant_src,
+ }) => {
+ let dst = f(dst, true);
+ let ptr_src = f(ptr_src, false);
+ let constant_src = f(constant_src, false);
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
- }
+ offset_src: constant_src,
+ })
}
}
}
@@ -4740,6 +5107,70 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> { }
}
+impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
+    // Runs `visitor` over the three operands, in the order destination,
+    // pointer source, offset, and rebuilds the statement for the target
+    // argument flavour `To`.
+    //
+    // Destination and pointer source are both visited with the full pointer
+    // type (`underlying_type` in `state_space`); the offset is visited as an
+    // s64 operand with default semantics.
+    fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<P, To>>(
+        self,
+        visitor: &mut V,
+    ) -> Result<PtrAccess<To>, TranslateError> {
+        let PtrAccess {
+            underlying_type,
+            state_space,
+            dst,
+            ptr_src,
+            offset_src,
+        } = self;
+        // Local and param pointers are register pointers, every other state
+        // space is addressed through a physical pointer.
+        let pointer_sema = match state_space {
+            ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
+                ArgumentSemantics::RegisterPointer
+            }
+            ast::LdStateSpace::Const
+            | ast::LdStateSpace::Global
+            | ast::LdStateSpace::Shared
+            | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
+        };
+        let pointer_type = ast::Type::Pointer(underlying_type.clone(), state_space);
+        let dst_desc = ArgumentDescriptor {
+            op: dst,
+            is_dst: true,
+            sema: pointer_sema,
+        };
+        let dst = visitor.id(dst_desc, Some(&pointer_type))?;
+        let ptr_desc = ArgumentDescriptor {
+            op: ptr_src,
+            is_dst: false,
+            sema: pointer_sema,
+        };
+        let ptr_src = visitor.id(ptr_desc, Some(&pointer_type))?;
+        let offset_desc = ArgumentDescriptor {
+            op: offset_src,
+            is_dst: false,
+            sema: ArgumentSemantics::Default,
+        };
+        let offset_src =
+            visitor.operand(offset_desc, &ast::Type::Scalar(ast::ScalarType::S64))?;
+        Ok(PtrAccess {
+            underlying_type,
+            state_space,
+            dst,
+            ptr_src,
+            offset_src,
+        })
+    }
+}
+
+impl VisitVariable for PtrAccess<TypedArgParams> {
+    // Applies `f` to every operand through `PtrAccess::map` and wraps the
+    // outcome back into a `Statement::PtrAccess`.
+    fn visit_variable<
+        'a,
+        F: FnMut(
+            ArgumentDescriptor<spirv::Word>,
+            Option<&ast::Type>,
+        ) -> Result<spirv::Word, TranslateError>,
+    >(
+        self,
+        f: &mut F,
+    ) -> Result<TypedStatement, TranslateError> {
+        let visited = self.map(f)?;
+        Ok(Statement::PtrAccess(visited))
+    }
+}
+
pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
@@ -5035,6 +5466,14 @@ pub struct ArgumentDescriptor<Op> { sema: ArgumentSemantics,
}
+// Pointer-offset statement: `dst` receives `ptr_src` moved by `offset_src`.
+// When emitted, the source pointer is bitcast to a u8 pointer in
+// `state_space`, offset through an in-bounds pointer access chain, and the
+// result is bitcast to a pointer to `underlying_type` in `state_space`.
+pub struct PtrAccess<P: ast::ArgParams> {
+ // Pointee type of the resulting pointer.
+ underlying_type: ast::PointerType,
+ // State space of both the source and the resulting pointer.
+ state_space: ast::LdStateSpace,
+ // Id receiving the offset pointer.
+ dst: spirv::Word,
+ // Id of the pointer being offset.
+ ptr_src: spirv::Word,
+ // Offset operand; visited as an s64 value by `PtrAccess::map`.
+ offset_src: P::Operand,
+}
+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access
@@ -5266,6 +5705,56 @@ impl VisitVariable for ast::Instruction<TypedArgParams> { fn visit_variable<
'a,
F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ Ok(Statement::Instruction(self.map(f)?))
+ }
+}
+
+impl ImplicitConversion {
+    // Passes the conversion's destination and then its source through
+    // `visitor`, each with its own semantics (`dst_sema` / `src_sema`) and its
+    // own type (`to` / `from`), and returns the conversion with both ids
+    // replaced. All remaining fields are carried over untouched.
+    fn map<T, U, V>(
+        self,
+        visitor: &mut V,
+    ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError>
+    where
+        T: ArgParamsEx<Id = spirv::Word>,
+        U: ArgParamsEx<Id = spirv::Word>,
+        V: ArgumentMapVisitor<T, U>,
+    {
+        let dst_desc = ArgumentDescriptor {
+            op: self.dst,
+            is_dst: true,
+            sema: self.dst_sema,
+        };
+        let dst = visitor.id(dst_desc, Some(&self.to))?;
+        let src_desc = ArgumentDescriptor {
+            op: self.src,
+            is_dst: false,
+            sema: self.src_sema,
+        };
+        let src = visitor.id(src_desc, Some(&self.from))?;
+        Ok(Statement::Conversion(ImplicitConversion {
+            src,
+            dst,
+            ..self
+        }))
+    }
+}
+
+impl VisitVariable for ImplicitConversion {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
@@ -5273,7 +5762,21 @@ impl VisitVariable for ast::Instruction<TypedArgParams> { self,
f: &mut F,
) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
+ self.map(f)
+ }
+}
+
+impl VisitVariableExpanded for ImplicitConversion {
+ fn visit_variable_extended<
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<ExpandedStatement, TranslateError> {
+ self.map(f)
}
}
@@ -5708,6 +6211,8 @@ struct ImplicitConversion { from: ast::Type,
to: ast::Type,
kind: ConversionKind,
+ src_sema: ArgumentSemantics,
+ dst_sema: ArgumentSemantics,
}
#[derive(PartialEq, Copy, Clone)]
|