summaryrefslogtreecommitdiffhomepage
path: root/ptx/src/test/spirv_run/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/test/spirv_run/mod.rs')
-rw-r--r--ptx/src/test/spirv_run/mod.rs27
1 files changed, 18 insertions, 9 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 14c3bc9..4c5f9b3 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -107,27 +107,33 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
assert!(errors.len() == 0);
- let (spirv, _) = translate::to_spirv(ast)?;
+ let notcuda_module = translate::to_spirv_module(ast)?;
let name = CString::new(name)?;
- let result =
- run_spirv(name.as_c_str(), &spirv, input, output).map_err(|err| DisplayError { err })?;
+ let result = run_spirv(name.as_c_str(), notcuda_module, input, output)
+ .map_err(|err| DisplayError { err })?;
assert_eq!(output, result.as_slice());
Ok(())
}
fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
name: &CStr,
- spirv: &[u32],
+ module: translate::Module,
input: &[T],
output: &mut [T],
) -> ze::Result<Vec<T>> {
ze::init()?;
+ let spirv = module.spirv.assemble();
let byte_il = unsafe {
slice::from_raw_parts::<u8>(
spirv.as_ptr() as *const _,
spirv.len() * mem::size_of::<u32>(),
)
};
+ let use_shared_mem = module
+ .kernel_info
+ .get(name.to_str().unwrap())
+ .unwrap()
+ .uses_shared_mem;
let mut result = vec![0u8.into(); output.len()];
{
let mut drivers = ze::Driver::get()?;
@@ -140,7 +146,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let module = match module {
Ok(m) => m,
Err(err) => {
- let raw_err_string = log.get_cstring()?;
+ let raw_err_string = log.get_cstring()?;
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
@@ -164,6 +170,9 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
kernel.set_arg_buffer(1, out_b_ptr_mut)?;
+ if use_shared_mem {
+ unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
+ }
cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
queue.execute(cmd_list)?;
@@ -179,7 +188,7 @@ fn test_spvtxt_assert<'a>(
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
- let (ptx_mod, _) = translate::to_spirv_module(ast)?;
+ let spirv_module = translate::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
@@ -211,9 +220,9 @@ fn test_spvtxt_assert<'a>(
rspirv::binary::parse_words(&parsed_spirv, &mut loader)?;
let spvtxt_mod = loader.module();
unsafe { spirv_tools::spvBinaryDestroy(spv_binary) };
- if !is_spirv_fn_equal(&ptx_mod.functions[0], &spvtxt_mod.functions[0]) {
+ if !is_spirv_fn_equal(&spirv_module.spirv.functions[0], &spvtxt_mod.functions[0]) {
// We could simply use ptx_mod.disassemble, but SPIRV-Tools text formattinmg is so much nicer
- let spv_from_ptx_binary = ptx_mod.assemble();
+ let spv_from_ptx_binary = spirv_module.spirv.assemble();
let mut spv_text: spirv_tools::spv_text = ptr::null_mut();
let result = unsafe {
spirv_tools::spvBinaryToText(
@@ -234,7 +243,7 @@ fn test_spvtxt_assert<'a>(
// TODO: stop leaking kernel text
Cow::Borrowed(spv_from_ptx_text)
} else {
- Cow::Owned(ptx_mod.disassemble())
+ Cow::Owned(spirv_module.spirv.disassemble())
};
if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") {
let mut path = PathBuf::from(dump_path);