diff options
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shared_unify_local.ptx | 43 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shared_unify_local.spvtxt | 117 | ||||
-rw-r--r-- | ptx/src/translate.rs | 25 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 16 |
5 files changed, 187 insertions, 15 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index dfc252d..f5dfa64 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -219,6 +219,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]);
+test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]);
test_ptx!(assertfail);
test_ptx!(func_ptr);
diff --git a/ptx/src/test/spirv_run/shared_unify_local.ptx b/ptx/src/test/spirv_run/shared_unify_local.ptx new file mode 100644 index 0000000..84f3a50 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_local.ptx @@ -0,0 +1,43 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .b32 shared_ex[];
+
+.func (.reg .b64 out) add(.reg .u64 temp2)
+{
+ .shared .align 4 .u64 shared_mod;
+ .reg .u64 temp1;
+ st.shared.u64 [shared_mod], temp2;
+ ld.shared.u64 temp1, [shared_mod];
+ ld.shared.u64 temp2, [shared_ex];
+ add.u64 out, temp2, temp1;
+ ret;
+}
+
+.func (.reg .b64 out) set_shared_temp1(.reg .b64 temp1, .reg .u64 temp2)
+{
+ st.shared.u64 [shared_ex], temp1;
+ call (out), add, (temp2);
+ ret;
+}
+
+.visible .entry shared_unify_local(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp1, [in_addr];
+ ld.global.u64 temp2, [in_addr+8];
+ call (temp2), set_shared_temp1, (temp1, temp2);
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shared_unify_local.spvtxt b/ptx/src/test/spirv_run/shared_unify_local.spvtxt new file mode 100644 index 0000000..dc00c2f --- /dev/null +++ b/ptx/src/test/spirv_run/shared_unify_local.spvtxt @@ -0,0 +1,117 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %64 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %31 "shared_unify_local" %1 %5 + OpExecutionMode %31 ContractionOff + OpDecorate %5 Alignment 4 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 +%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint + %1 = OpVariable %_ptr_Workgroup_uint Workgroup + %ulong = OpTypeInt 64 0 +%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong + %5 = OpVariable %_ptr_Workgroup_ulong Workgroup + %70 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %72 = OpTypeFunction %ulong %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong + %73 = OpTypeFunction %void %ulong %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %2 = OpFunction %ulong None %70 + %7 = OpFunctionParameter %ulong + %60 = OpFunctionParameter %_ptr_Workgroup_uint + %61 = OpFunctionParameter %_ptr_Workgroup_ulong + %17 = OpLabel + %4 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + OpStore %4 %7 + %8 = OpLoad %ulong %4 + OpStore %61 %8 Aligned 8 + %9 = OpLoad %ulong %61 Aligned 8 + OpStore %6 %9 + %15 = OpBitcast %_ptr_Workgroup_ulong %60 + %10 = OpLoad %ulong %15 Aligned 8 + OpStore %4 %10 + %12 = OpLoad %ulong %4 + %13 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %12 %13 + %11 = OpCopyObject %ulong %16 + OpStore %3 %11 + %14 = OpLoad %ulong %3 + OpReturnValue %14 + OpFunctionEnd + %18 = OpFunction %ulong None %72 + %22 = OpFunctionParameter %ulong + %23 = OpFunctionParameter %ulong + %62 = OpFunctionParameter %_ptr_Workgroup_uint + %63 = OpFunctionParameter %_ptr_Workgroup_ulong + %30 = OpLabel + %20 = OpVariable %_ptr_Function_ulong Function + %21 = OpVariable %_ptr_Function_ulong Function + %19 = OpVariable %_ptr_Function_ulong Function + OpStore %20 %22 + OpStore %21 %23 + %24 = OpLoad %ulong %20 + %28 = OpBitcast %_ptr_Workgroup_ulong %62 + %29 = OpCopyObject %ulong %24 + OpStore %28 %29 Aligned 8 + %26 = OpLoad %ulong %21 + %25 = OpFunctionCall %ulong %2 %26 %62 %63 + OpStore %19 %25 + %27 = OpLoad %ulong %19 + OpReturnValue %27 + OpFunctionEnd + %31 = OpFunction %void None %73 + %38 = OpFunctionParameter %ulong + %39 = OpFunctionParameter %ulong + %58 = OpLabel + %32 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + %34 = OpVariable %_ptr_Function_ulong Function + %35 = OpVariable %_ptr_Function_ulong Function + %36 = OpVariable %_ptr_Function_ulong Function + %37 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %38 + OpStore %33 %39 + %40 = OpLoad %ulong %32 Aligned 8 + OpStore %34 %40 + %41 = OpLoad %ulong %33 Aligned 8 + OpStore %35 %41 + %43 = OpLoad %ulong %34 + %53 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %43 + %42 = OpLoad %ulong %53 Aligned 8 + OpStore %36 %42 + %45 = OpLoad %ulong %34 + %54 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %45 + %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %54 + %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %ulong_8 + %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %78 + %44 = OpLoad %ulong %52 Aligned 8 + OpStore %37 %44 + %47 = OpLoad %ulong %36 + %48 = OpLoad %ulong %37 + %56 = OpCopyObject %ulong %47 + %55 = OpFunctionCall %ulong %18 %56 %48 %1 %5 + %46 = OpCopyObject %ulong %55 + OpStore %37 %46 + %49 = OpLoad %ulong %35 + %50 = OpLoad %ulong %37 + %57 = OpConvertUToPtr %_ptr_Generic_ulong %49 + OpStore %57 %50 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 165997e..db1063b 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -789,6 +789,14 @@ impl<'input> MethodsCallMap<'input> { })
}
+ fn methods(
+ &self,
+ ) -> impl Iterator<Item = (ast::MethodName<'input, spirv::Word>, &HashSet<spirv::Word>)> {
+ self.map
+ .iter()
+ .map(|(method, children)| (*method, children))
+ }
+
fn visit_callees(
&self,
method: ast::MethodName<'input, spirv::Word>,
@@ -1102,18 +1110,23 @@ fn resolve_indirect_uses_of_globals_shared<'input>( kernels_methods_call_map: &MethodsCallMap<'input>,
) -> HashMap<ast::MethodName<'input, spirv::Word>, BTreeSet<spirv::Word>> {
let mut result = HashMap::new();
- for (method, direct_globals) in methods_use_of_globals_shared.iter() {
- let mut indirect_globals = direct_globals.iter().copied().collect::<BTreeSet<_>>();
- kernels_methods_call_map.visit_callees(*method, |func| {
+ for (method, callees) in kernels_methods_call_map.methods() {
+ let mut indirect_globals = methods_use_of_globals_shared
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .collect::<BTreeSet<_>>();
+ for &callee in callees {
indirect_globals.extend(
methods_use_of_globals_shared
- .get(&ast::MethodName::Func(func))
+ .get(&ast::MethodName::Func(callee))
.into_iter()
.flatten()
.copied(),
);
- });
- result.insert(*method, indirect_globals);
+ }
+ result.insert(method, indirect_globals);
}
result
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 9732ec9..24fa88a 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -54,20 +54,16 @@ impl SpirvModule { } pub(crate) fn load(module: *mut CUmodule, fname: *const i8) -> Result<(), hipError_t> { - let length = (0..) - .position(|i| unsafe { *fname.add(i) == 0 }) - .ok_or(hipError_t::hipErrorInvalidValue)?; - let file_name = CStr::from_bytes_with_nul(unsafe { slice::from_raw_parts(fname as _, length) }) - .map_err(|_| hipError_t::hipErrorInvalidValue)?; - let valid_file_name = file_name + let file_name = unsafe { CStr::from_ptr(fname) } .to_str() .map_err(|_| hipError_t::hipErrorInvalidValue)?; - let mut file = File::open(valid_file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?; + let mut file = File::open(file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?; let mut file_buffer = Vec::new(); file.read_to_end(&mut file_buffer) .map_err(|_| hipError_t::hipErrorUnknown)?; - drop(file); - load_data(module, file_buffer.as_ptr() as _) + let result = load_data(module, file_buffer.as_ptr() as _); + drop(file_buffer); + result } pub(crate) fn load_data( @@ -201,6 +197,8 @@ pub(crate) fn compile_amd<'a>( .arg("-nogpulib") .arg("-mno-wavefrontsize64") .arg("-O3") + .arg("-Xclang") + .arg("-O3") .arg("-Xlinker") .arg("--no-undefined") .arg("-target") |