aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/tests/llama.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/tests/llama.rs')
-rw-r--r--zluda/tests/llama.rs84
1 files changed, 84 insertions, 0 deletions
diff --git a/zluda/tests/llama.rs b/zluda/tests/llama.rs
new file mode 100644
index 0000000..de73ac2
--- /dev/null
+++ b/zluda/tests/llama.rs
@@ -0,0 +1,84 @@
+use crate::common::CudaDriverFns;
+use cuda_types::*;
+use std::{ffi::c_void, mem, ptr};
+
+mod common;
+
+cuda_driver_test!(llama);
+
+unsafe fn llama<T: CudaDriverFns>(cuda: T) {
+ let kernel = concat!(include_str!("llama.ptx"), "\0");
+ assert_eq!(cuda.cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(
+ cuda.cuCtxCreate_v2(&mut ctx, 0, CUdevice_v1(0)),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut module = ptr::null_mut();
+ assert_eq!(
+ cuda.cuModuleLoadData(&mut module, kernel.as_ptr() as _),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut buffer_input = mem::zeroed();
+ assert_eq!(
+ cuda.cuMemAlloc_v2(&mut buffer_input, 4096),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut host_buffer = include_bytes!("llama.bin").to_vec();
+ assert_eq!(
+ cuda.cuMemcpyHtoD_v2(buffer_input, host_buffer.as_ptr().cast(), host_buffer.len()),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut buffer_output = mem::zeroed();
+ assert_eq!(
+ cuda.cuMemAlloc_v2(&mut buffer_output, 97 * mem::size_of::<f32>()),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut kernel = mem::zeroed();
+ assert_eq!(
+ cuda.cuModuleGetFunction(
+ &mut kernel,
+ module,
+ b"_Z21dequantize_block_q6_KPKvPf\0".as_ptr() as _
+ ),
+ CUresult::CUDA_SUCCESS
+ );
+ let mut args = [
+ &mut buffer_input as *mut _ as *mut c_void,
+ &mut buffer_output as *mut _ as _,
+ ];
+ assert_eq!(
+ cuda.cuLaunchKernel(
+ kernel,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 0,
+ ptr::null_mut(),
+ &mut args as _,
+ ptr::null_mut()
+ ),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(
+ cuda.cuStreamSynchronize(ptr::null_mut()),
+ CUresult::CUDA_SUCCESS
+ );
+ host_buffer.fill(0);
+ assert_eq!(
+ cuda.cuMemcpyDtoH_v2(
+ host_buffer.as_mut_ptr().cast(),
+ buffer_output,
+ host_buffer.len()
+ ),
+ CUresult::CUDA_SUCCESS
+ );
+ let host_buffer = host_buffer.align_to::<u32>().1;
+ assert_eq!(host_buffer[0], 0xBC6C7800);
+ assert_eq!(host_buffer[32], 0x3B260800);
+ assert_eq!(host_buffer[64], 0xBC301800);
+ assert_eq!(host_buffer[96], 0x3C0AFD00);
+}