aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_take_address.ptx27
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt68
-rw-r--r--ptx/src/translate.rs35
4 files changed, 115 insertions, 17 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f18b15c..027e891 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -137,6 +137,7 @@ test_ptx!(stateful_ld_st_simple, [121u64], [121u64]);
test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]);
test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]);
test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]);
+test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]);
struct DisplayError<T: Debug> {
err: T,
@@ -261,6 +262,7 @@ fn test_spvtxt_assert<'a>(
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
let spirv_module = translate::to_spirv_module(ast)?;
+ eprintln!("{}", rspirv::binary::Disassemble::disassemble(&spirv_module.spirv));
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.ptx b/ptx/src/test/spirv_run/shared_ptr_take_address.ptx
new file mode 100644
index 0000000..e892993
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_ptr_take_address.ptx
@@ -0,0 +1,27 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .align 4 .b8 shared_mem[];
+
+.visible .entry shared_ptr_take_address(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 shared_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+ mov.u64 shared_addr, shared_mem;
+
+ ld.global.u64 temp1, [in_addr];
+ st.shared.u64 [shared_addr], temp1;
+ ld.shared.u64 temp2, [shared_addr];
+ st.global.u64 [out_addr], temp2;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
new file mode 100644
index 0000000..d77c2c8
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
@@ -0,0 +1,68 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %33 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
+ OpDecorate %1 Alignment 4
+ %void = OpTypeVoid
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar
+ %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup
+ %ulong = OpTypeInt 64 0
+ %39 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
+%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
+ %2 = OpFunction %void None %39
+ %10 = OpFunctionParameter %ulong
+ %11 = OpFunctionParameter %ulong
+ %31 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %40 = OpLabel
+ %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ OpStore %32 %31
+ OpBranch %29
+ %29 = OpLabel
+ OpStore %3 %10
+ OpStore %4 %11
+ %12 = OpLoad %ulong %3
+ OpStore %5 %12
+ %13 = OpLoad %ulong %4
+ OpStore %6 %13
+ %15 = OpLoad %_ptr_Workgroup_uchar %32
+ %24 = OpConvertPtrToU %ulong %15
+ %14 = OpCopyObject %ulong %24
+ OpStore %7 %14
+ %17 = OpLoad %ulong %5
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17
+ %16 = OpLoad %ulong %25
+ OpStore %8 %16
+ %18 = OpLoad %ulong %7
+ %19 = OpLoad %ulong %8
+ %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18
+ OpStore %26 %19
+ %21 = OpLoad %ulong %7
+ %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21
+ %20 = OpLoad %ulong %27
+ OpStore %9 %20
+ %22 = OpLoad %ulong %6
+ %23 = OpLoad %ulong %9
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22
+ OpStore %28 %23
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 328bf30..20c3edb 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -843,24 +843,25 @@ fn replace_uses_of_shared_memory<'a>(
statement => {
let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
- let replacement_id = new_id();
- if *typ != ast::SizedScalarType::B8 {
- result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
- dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- }));
+ if *typ == ast::SizedScalarType::B8 {
+ return shared_var_id;
}
+ let replacement_id = new_id();
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: shared_var_id,
+ dst: replacement_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::B8),
+ ast::LdStateSpace::Shared,
+ ),
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar((*typ).into()),
+ ast::LdStateSpace::Shared,
+ ),
+ kind: ConversionKind::PtrToPtr { spirv_ptr: true },
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
+ }));
replacement_id
} else {
id