aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-29 21:49:47 +0000
committerAndrzej Janik <[email protected]>2021-09-29 21:49:47 +0000
commit816365e7df5d0bf6464f7718553d845e72637eff (patch)
treeefe16a5592b22f45584e994fefb5a7bb8b7550f4
parent0172dc58e52f2ac1e4d01951002a94a69b3589d0 (diff)
downloadZLUDA-816365e7df5d0bf6464f7718553d845e72637eff.tar.gz
ZLUDA-816365e7df5d0bf6464f7718553d845e72637eff.zip
Fix shared munging pass and add fix cuModuleLoadData
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/shared_unify_local.ptx43
-rw-r--r--ptx/src/test/spirv_run/shared_unify_local.spvtxt117
-rw-r--r--ptx/src/translate.rs25
-rw-r--r--zluda/src/impl/module.rs16
5 files changed, 187 insertions, 15 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index dfc252d..f5dfa64 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -219,6 +219,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64, 7682u64], [15363u64]);
+test_ptx!(shared_unify_local, [16752u64, 714u64], [17466u64]);
test_ptx!(assertfail);
test_ptx!(func_ptr);
diff --git a/ptx/src/test/spirv_run/shared_unify_local.ptx b/ptx/src/test/spirv_run/shared_unify_local.ptx
new file mode 100644
index 0000000..84f3a50
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_local.ptx
@@ -0,0 +1,43 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .b32 shared_ex[];
+
+.func (.reg .b64 out) add(.reg .u64 temp2)
+{
+ .shared .align 4 .u64 shared_mod;
+ .reg .u64 temp1;
+ st.shared.u64 [shared_mod], temp2;
+ ld.shared.u64 temp1, [shared_mod];
+ ld.shared.u64 temp2, [shared_ex];
+ add.u64 out, temp2, temp1;
+ ret;
+}
+
+.func (.reg .b64 out) set_shared_temp1(.reg .b64 temp1, .reg .u64 temp2)
+{
+ st.shared.u64 [shared_ex], temp1;
+ call (out), add, (temp2);
+ ret;
+}
+
+.visible .entry shared_unify_local(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp1, [in_addr];
+ ld.global.u64 temp2, [in_addr+8];
+ call (temp2), set_shared_temp1, (temp1, temp2);
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shared_unify_local.spvtxt b/ptx/src/test/spirv_run/shared_unify_local.spvtxt
new file mode 100644
index 0000000..dc00c2f
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_local.spvtxt
@@ -0,0 +1,117 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
+ %64 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %31 "shared_unify_local" %1 %5
+ OpExecutionMode %31 ContractionOff
+ OpDecorate %5 Alignment 4
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
+ %1 = OpVariable %_ptr_Workgroup_uint Workgroup
+ %ulong = OpTypeInt 64 0
+%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
+ %5 = OpVariable %_ptr_Workgroup_ulong Workgroup
+ %70 = OpTypeFunction %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %72 = OpTypeFunction %ulong %ulong %ulong %_ptr_Workgroup_uint %_ptr_Workgroup_ulong
+ %73 = OpTypeFunction %void %ulong %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %ulong_8 = OpConstant %ulong 8
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %2 = OpFunction %ulong None %70
+ %7 = OpFunctionParameter %ulong
+ %60 = OpFunctionParameter %_ptr_Workgroup_uint
+ %61 = OpFunctionParameter %_ptr_Workgroup_ulong
+ %17 = OpLabel
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ OpStore %4 %7
+ %8 = OpLoad %ulong %4
+ OpStore %61 %8 Aligned 8
+ %9 = OpLoad %ulong %61 Aligned 8
+ OpStore %6 %9
+ %15 = OpBitcast %_ptr_Workgroup_ulong %60
+ %10 = OpLoad %ulong %15 Aligned 8
+ OpStore %4 %10
+ %12 = OpLoad %ulong %4
+ %13 = OpLoad %ulong %6
+ %16 = OpIAdd %ulong %12 %13
+ %11 = OpCopyObject %ulong %16
+ OpStore %3 %11
+ %14 = OpLoad %ulong %3
+ OpReturnValue %14
+ OpFunctionEnd
+ %18 = OpFunction %ulong None %72
+ %22 = OpFunctionParameter %ulong
+ %23 = OpFunctionParameter %ulong
+ %62 = OpFunctionParameter %_ptr_Workgroup_uint
+ %63 = OpFunctionParameter %_ptr_Workgroup_ulong
+ %30 = OpLabel
+ %20 = OpVariable %_ptr_Function_ulong Function
+ %21 = OpVariable %_ptr_Function_ulong Function
+ %19 = OpVariable %_ptr_Function_ulong Function
+ OpStore %20 %22
+ OpStore %21 %23
+ %24 = OpLoad %ulong %20
+ %28 = OpBitcast %_ptr_Workgroup_ulong %62
+ %29 = OpCopyObject %ulong %24
+ OpStore %28 %29 Aligned 8
+ %26 = OpLoad %ulong %21
+ %25 = OpFunctionCall %ulong %2 %26 %62 %63
+ OpStore %19 %25
+ %27 = OpLoad %ulong %19
+ OpReturnValue %27
+ OpFunctionEnd
+ %31 = OpFunction %void None %73
+ %38 = OpFunctionParameter %ulong
+ %39 = OpFunctionParameter %ulong
+ %58 = OpLabel
+ %32 = OpVariable %_ptr_Function_ulong Function
+ %33 = OpVariable %_ptr_Function_ulong Function
+ %34 = OpVariable %_ptr_Function_ulong Function
+ %35 = OpVariable %_ptr_Function_ulong Function
+ %36 = OpVariable %_ptr_Function_ulong Function
+ %37 = OpVariable %_ptr_Function_ulong Function
+ OpStore %32 %38
+ OpStore %33 %39
+ %40 = OpLoad %ulong %32 Aligned 8
+ OpStore %34 %40
+ %41 = OpLoad %ulong %33 Aligned 8
+ OpStore %35 %41
+ %43 = OpLoad %ulong %34
+ %53 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %43
+ %42 = OpLoad %ulong %53 Aligned 8
+ OpStore %36 %42
+ %45 = OpLoad %ulong %34
+ %54 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %45
+ %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %54
+ %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %ulong_8
+ %52 = OpBitcast %_ptr_CrossWorkgroup_ulong %78
+ %44 = OpLoad %ulong %52 Aligned 8
+ OpStore %37 %44
+ %47 = OpLoad %ulong %36
+ %48 = OpLoad %ulong %37
+ %56 = OpCopyObject %ulong %47
+ %55 = OpFunctionCall %ulong %18 %56 %48 %1 %5
+ %46 = OpCopyObject %ulong %55
+ OpStore %37 %46
+ %49 = OpLoad %ulong %35
+ %50 = OpLoad %ulong %37
+ %57 = OpConvertUToPtr %_ptr_Generic_ulong %49
+ OpStore %57 %50 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 165997e..db1063b 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -789,6 +789,14 @@ impl<'input> MethodsCallMap<'input> {
})
}
+ fn methods(
+ &self,
+ ) -> impl Iterator<Item = (ast::MethodName<'input, spirv::Word>, &HashSet<spirv::Word>)> {
+ self.map
+ .iter()
+ .map(|(method, children)| (*method, children))
+ }
+
fn visit_callees(
&self,
method: ast::MethodName<'input, spirv::Word>,
@@ -1102,18 +1110,23 @@ fn resolve_indirect_uses_of_globals_shared<'input>(
kernels_methods_call_map: &MethodsCallMap<'input>,
) -> HashMap<ast::MethodName<'input, spirv::Word>, BTreeSet<spirv::Word>> {
let mut result = HashMap::new();
- for (method, direct_globals) in methods_use_of_globals_shared.iter() {
- let mut indirect_globals = direct_globals.iter().copied().collect::<BTreeSet<_>>();
- kernels_methods_call_map.visit_callees(*method, |func| {
+ for (method, callees) in kernels_methods_call_map.methods() {
+ let mut indirect_globals = methods_use_of_globals_shared
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .collect::<BTreeSet<_>>();
+ for &callee in callees {
indirect_globals.extend(
methods_use_of_globals_shared
- .get(&ast::MethodName::Func(func))
+ .get(&ast::MethodName::Func(callee))
.into_iter()
.flatten()
.copied(),
);
- });
- result.insert(*method, indirect_globals);
+ }
+ result.insert(method, indirect_globals);
}
result
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index 9732ec9..24fa88a 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -54,20 +54,16 @@ impl SpirvModule {
}
pub(crate) fn load(module: *mut CUmodule, fname: *const i8) -> Result<(), hipError_t> {
- let length = (0..)
- .position(|i| unsafe { *fname.add(i) == 0 })
- .ok_or(hipError_t::hipErrorInvalidValue)?;
- let file_name = CStr::from_bytes_with_nul(unsafe { slice::from_raw_parts(fname as _, length) })
- .map_err(|_| hipError_t::hipErrorInvalidValue)?;
- let valid_file_name = file_name
+ let file_name = unsafe { CStr::from_ptr(fname) }
.to_str()
.map_err(|_| hipError_t::hipErrorInvalidValue)?;
- let mut file = File::open(valid_file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?;
+ let mut file = File::open(file_name).map_err(|_| hipError_t::hipErrorFileNotFound)?;
let mut file_buffer = Vec::new();
file.read_to_end(&mut file_buffer)
.map_err(|_| hipError_t::hipErrorUnknown)?;
- drop(file);
- load_data(module, file_buffer.as_ptr() as _)
+ let result = load_data(module, file_buffer.as_ptr() as _);
+ drop(file_buffer);
+ result
}
pub(crate) fn load_data(
@@ -201,6 +197,8 @@ pub(crate) fn compile_amd<'a>(
.arg("-nogpulib")
.arg("-mno-wavefrontsize64")
.arg("-O3")
+ .arg("-Xclang")
+ .arg("-O3")
.arg("-Xlinker")
.arg("--no-undefined")
.arg("-target")