aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/test/spirv_run/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/test/spirv_run/mod.rs')
-rw-r--r--ptx/src/test/spirv_run/mod.rs106
1 files changed, 86 insertions, 20 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 6c073f3..512b6cf 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -1,18 +1,6 @@
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
-use hip_runtime_sys::hipGetDeviceProperties;
-use hip_runtime_sys::hipInit;
-use hip_runtime_sys::hipMalloc;
-use hip_runtime_sys::hipMemcpyAsync;
-use hip_runtime_sys::hipMemcpyKind;
-use hip_runtime_sys::hipMemcpyWithStream;
-use hip_runtime_sys::hipMemset;
-use hip_runtime_sys::hipModuleGetFunction;
-use hip_runtime_sys::hipModuleLaunchKernel;
-use hip_runtime_sys::hipModuleLoadData;
-use hip_runtime_sys::hipStreamCreate;
-use hip_runtime_sys::hipStreamSynchronize;
use rspirv::{
binary::{Assemble, Disassemble},
dr::{Block, Function, Instruction, Loader, Operand},
@@ -46,7 +34,17 @@ macro_rules! test_ptx {
let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
let input = $input;
let mut output = $output;
- test_ptx_assert(stringify!($fn_name), ptx, &input, &mut output)
+ test_hip_assert(stringify!($fn_name), ptx, &input, &mut output)
+ }
+ }
+
+ paste::item! {
+ #[test]
+ fn [<$fn_name _cuda>]() -> Result<(), Box<dyn std::error::Error>> {
+ let ptx = include_str!(concat!(stringify!($fn_name), ".ptx"));
+ let input = $input;
+ let mut output = $output;
+ test_cuda_assert(stringify!($fn_name), ptx, &input, &mut output)
}
}
@@ -75,7 +73,7 @@ macro_rules! test_ptx {
}
test_ptx!(ld_st, [1u64], [1u64]);
-test_ptx!(ld_st_implicit, [0.5f32], [0.5f32]);
+test_ptx!(ld_st_implicit, [0.5f32, 0.25f32], [0.5f32]);
test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
@@ -99,7 +97,8 @@ test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]);
test_ptx!(mov_address, [0xDEADu64], [0u64]);
test_ptx!(b64tof64, [111u64], [111u64]);
-test_ptx!(implicit_param, [34u32], [34u32]);
+// This segfaults NV compiler
+// test_ptx!(implicit_param, [34u32], [34u32]);
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(
@@ -178,8 +177,6 @@ test_ptx!(stateful_ld_st_ntid, [123u64], [123u64]);
test_ptx!(stateful_ld_st_ntid_chain, [12651u64], [12651u64]);
test_ptx!(stateful_ld_st_ntid_sub, [96311u64], [96311u64]);
test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]);
-// For now, we just make sure that it builds and links
-test_ptx!(assertfail, [716523871u64], [716523872u64]);
test_ptx!(cvt_s64_s32, [-1i32], [-1i64]);
test_ptx!(add_tuning, [2u64], [3u64]);
test_ptx!(add_non_coherent, [3u64], [4u64]);
@@ -224,6 +221,7 @@ test_ptx!(membar, [152731u32], [152731u32]);
test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
test_ptx!(shared_unify_private, [67153u64], [134306u64]);
+test_ptx!(assertfail);
test_ptx!(func_ptr);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);
@@ -246,7 +244,7 @@ impl<T: Debug> Debug for DisplayError<T> {
impl<T: Debug> error::Error for DisplayError<T> {}
-fn test_ptx_assert<
+fn test_hip_assert<
'a,
Input: From<u8> + Debug + Copy + PartialEq,
Output: From<u8> + Debug + Copy + PartialEq + Default,
@@ -261,12 +259,29 @@ fn test_ptx_assert<
assert!(errors.len() == 0);
let zluda_module = translate::to_spirv_module(ast)?;
let name = CString::new(name)?;
- let result = run_spirv(name.as_c_str(), zluda_module, input, output)
+ let result = run_hip(name.as_c_str(), zluda_module, input, output)
.map_err(|err| DisplayError { err })?;
assert_eq!(result.as_slice(), output);
Ok(())
}
+fn test_cuda_assert<
+ 'a,
+ Input: From<u8> + Debug + Copy + PartialEq,
+ Output: From<u8> + Debug + Copy + PartialEq + Default,
+>(
+ name: &str,
+ ptx_text: &'a str,
+ input: &[Input],
+ output: &mut [Output],
+) -> Result<(), Box<dyn error::Error + 'a>> {
+ let name = CString::new(name)?;
+ let result =
+ run_cuda(name.as_c_str(), ptx_text, input, output).map_err(|err| DisplayError { err })?;
+ assert_eq!(result.as_slice(), output);
+ Ok(())
+}
+
macro_rules! hip_call {
($expr:expr) => {
#[allow(unused_unsafe)]
@@ -279,12 +294,60 @@ macro_rules! hip_call {
};
}
-fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
+macro_rules! cuda_call {
+ ($expr:expr) => {
+ #[allow(unused_unsafe)]
+ {
+ let err = unsafe { $expr };
+ if err != cuda_driver_sys::CUresult::CUDA_SUCCESS {
+ return Result::Err(err);
+ }
+ }
+ };
+}
+
+fn run_cuda<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
+ name: &CStr,
+ ptx_module: &str,
+ input: &[Input],
+ output: &mut [Output],
+) -> Result<Vec<Output>, cuda_driver_sys::CUresult> {
+ use cuda_driver_sys::*;
+ cuda_call! { cuInit(0) };
+ let ptx_module = CString::new(ptx_module).unwrap();
+ let mut result = vec![0u8.into(); output.len()];
+ {
+ let mut ctx = ptr::null_mut();
+ cuda_call! { cuCtxCreate_v2(&mut ctx, 0, 0) };
+ let mut module = ptr::null_mut();
+ cuda_call! { cuModuleLoadData(&mut module, ptx_module.as_ptr() as _) };
+ let mut kernel = ptr::null_mut();
+ cuda_call! { cuModuleGetFunction(&mut kernel, module, name.as_ptr()) };
+ let mut inp_b = unsafe { mem::zeroed() };
+ cuda_call! { cuMemAlloc_v2(&mut inp_b, input.len() * mem::size_of::<Input>()) };
+ let mut out_b = unsafe { mem::zeroed() };
+ cuda_call! { cuMemAlloc_v2(&mut out_b, output.len() * mem::size_of::<Output>()) };
+ cuda_call! { cuMemcpyHtoD_v2(inp_b, input.as_ptr() as _, input.len() * mem::size_of::<Input>()) };
+ cuda_call! { cuMemsetD8_v2(out_b, 0, output.len() * mem::size_of::<Output>()) };
+ let mut args = [&inp_b, &out_b];
+ cuda_call! { cuLaunchKernel(kernel, 1,1,1,1,1,1, 1024, 0 as _, args.as_mut_ptr() as _, ptr::null_mut()) };
+ cuda_call! { cuMemcpyDtoH_v2(result.as_mut_ptr() as _, out_b, output.len() * mem::size_of::<Output>()) };
+ cuda_call! { cuStreamSynchronize(0 as _) };
+ cuda_call! { cuMemFree_v2(inp_b) };
+ cuda_call! { cuMemFree_v2(out_b) };
+ cuda_call! { cuModuleUnload(module) };
+ cuda_call! { cuCtxDestroy_v2(ctx) };
+ }
+ Ok(result)
+}
+
+fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Default>(
name: &CStr,
module: translate::Module,
input: &[Input],
output: &mut [Output],
) -> Result<Vec<Output>, hipError_t> {
+ use hip_runtime_sys::*;
hip_call! { hipInit(0) };
let spirv = module.spirv.assemble();
let mut result = vec![0u8.into(); output.len()];
@@ -310,6 +373,9 @@ fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + D
hip_call! { hipModuleLaunchKernel(kernel, 1,1,1,1,1,1, 1024, stream, args.as_mut_ptr() as _, ptr::null_mut()) };
hip_call! { hipMemcpyAsync(result.as_mut_ptr() as _, out_b, output.len() * mem::size_of::<Output>(), hipMemcpyKind::hipMemcpyDeviceToHost, stream) };
hip_call! { hipStreamSynchronize(stream) };
+ hip_call! { hipFree(inp_b) };
+ hip_call! { hipFree(out_b) };
+ hip_call! { hipModuleUnload(module) };
}
Ok(result)
}