diff options
Diffstat (limited to 'zluda/src/impl/test.rs')
-rw-r--r-- | zluda/src/impl/test.rs | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/zluda/src/impl/test.rs b/zluda/src/impl/test.rs new file mode 100644 index 0000000..b36ccd8 --- /dev/null +++ b/zluda/src/impl/test.rs @@ -0,0 +1,157 @@ +#![allow(non_snake_case)] + +use crate::cuda as zluda; +use crate::cuda::CUstream; +use crate::cuda::CUuuid; +use crate::{ + cuda::{CUdevice, CUdeviceptr}, + r#impl::CUresult, +}; +use ::std::{ + ffi::c_void, + os::raw::{c_int, c_uint}, +}; +use cuda_driver_sys as cuda; + +#[macro_export] +macro_rules! cuda_driver_test { + ($func:ident) => { + paste! { + #[test] + fn [<$func _zluda>]() { + $func::<crate::r#impl::test::Zluda>() + } + + #[test] + fn [<$func _cuda>]() { + $func::<crate::r#impl::test::Cuda>() + } + } + }; +} + +pub trait CudaDriverFns { + fn cuInit(flags: c_uint) -> CUresult; + fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult; + fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult; + fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult; + fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult; + fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult; + fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult; + fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult; + fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult; + fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult; + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult; + fn cuMemFree_v2(mem: *mut c_void) -> CUresult; + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult; +} + +pub struct Zluda(); + +impl CudaDriverFns for Zluda { + fn cuInit(_flags: c_uint) -> CUresult { + zluda::cuInit(_flags as _) + } + + fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { + zluda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev)) + } + + fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { + zluda::cuCtxDestroy_v2(ctx as *mut _) + } + + fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { + zluda::cuCtxPopCurrent_v2(pctx as *mut _) + } + + fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { + zluda::cuCtxGetApiVersion(ctx as *mut _, version) + } + + fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { + zluda::cuCtxGetCurrent(pctx as *mut _) + } + fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { + zluda::cuMemAlloc_v2(dptr as *mut _, bytesize) + } + + fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { + zluda::cuDeviceGetUuid(uuid, CUdevice(dev)) + } + + fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { + zluda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active) + } + + fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { + zluda::cuStreamGetCtx(hStream, pctx as _) + } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + zluda::cuStreamCreate(stream, flags) + } + + fn cuMemFree_v2(dptr: *mut c_void) -> CUresult { + zluda::cuMemFree_v2(CUdeviceptr(dptr as _)) + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + zluda::cuStreamDestroy_v2(stream) + } +} + +pub struct Cuda(); + +impl CudaDriverFns for Cuda { + fn cuInit(flags: c_uint) -> CUresult { + unsafe { CUresult(cuda::cuInit(flags) as c_uint) } + } + + fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { + unsafe { CUresult(cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) as c_uint) } + } + + fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuCtxDestroy_v2(ctx as *mut _) as c_uint) } + } + + fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuCtxPopCurrent_v2(pctx as *mut _) as c_uint) } + } + + fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { + unsafe { CUresult(cuda::cuCtxGetApiVersion(ctx as *mut _, version) as c_uint) } + } + + fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuCtxGetCurrent(pctx as *mut _) as c_uint) } + } + fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { + unsafe { CUresult(cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) as c_uint) } + } + + fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { + unsafe { CUresult(cuda::cuDeviceGetUuid(uuid as *mut _, dev) as c_uint) } + } + + fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { + unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) } + } + + fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) } + } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) } + } + + fn cuMemFree_v2(mem: *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) } + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) } + } +} |