diff options
Diffstat (limited to 'zluda/src/impl/mod.rs')
-rw-r--r-- | zluda/src/impl/mod.rs | 68 |
1 files changed, 62 insertions, 6 deletions
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 03a68d8..8efd0a7 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -1,6 +1,8 @@ use cuda_types::*; use hip_runtime_sys::*; +pub(super) mod device; + #[cfg(debug_assertions)] pub(crate) fn unimplemented() -> CUresult { unimplemented!() @@ -11,16 +13,70 @@ pub(crate) fn unimplemented() -> CUresult { CUresult::ERROR_NOT_SUPPORTED } -pub(crate) trait FromCuda<T>: Sized { - fn from_cuda(t: T) -> Result<Self, CUerror>; +pub(crate) trait FromCuda<'a, T>: Sized { + fn from_cuda(t: &'a T) -> Result<Self, CUerror>; } -impl FromCuda<u32> for u32 { - fn from_cuda(x: u32) -> Result<Self, CUerror> { - Ok(x) - } +macro_rules! from_cuda_noop { + ($($type_:ty),*) => { + $( + impl<'a> FromCuda<'a, $type_> for $type_ { + fn from_cuda(x: &'a $type_) -> Result<Self, CUerror> { + Ok(*x) + } + } + + impl<'a> FromCuda<'a, *mut $type_> for &'a mut $type_ { + fn from_cuda(x: &'a *mut $type_) -> Result<Self, CUerror> { + match unsafe { x.as_mut() } { + Some(x) => Ok(x), + None => Err(CUerror::INVALID_VALUE), + } + } + } + )* + }; } +macro_rules! from_cuda_transmute { + ($($from:ty => $to:ty),*) => { + $( + impl<'a> FromCuda<'a, $from> for $to { + fn from_cuda(x: &'a $from) -> Result<Self, CUerror> { + Ok(unsafe { std::mem::transmute(*x) }) + } + } + + impl<'a> FromCuda<'a, *mut $from> for &'a mut $to { + fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> { + match unsafe { x.cast::<$to>().as_mut() } { + Some(x) => Ok(x), + None => Err(CUerror::INVALID_VALUE), + } + } + } + + impl<'a> FromCuda<'a, *mut $from> for * mut $to { + fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> { + Ok(x.cast::<$to>()) + } + } + )* + }; +} + +from_cuda_noop!( + *mut i8, + *mut usize, + i32, + u32, + cuda_types::CUdevprop, CUdevice_attribute +); +from_cuda_transmute!( + CUdevice => hipDevice_t, + CUuuid => hipUUID +); + pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t { unsafe { hipInit(flags) } } |