about | summary | refs | log | tree | commit | diff | homepage
path: root/zluda/src/impl/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/mod.rs')
-rw-r--r-- zluda/src/impl/mod.rs 349
1 files changed, 164 insertions, 185 deletions
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 1335ef6..766b4a5 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -1,230 +1,209 @@
-use hip_runtime_sys::hipError_t;
-
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
-use std::{
- ffi::c_void,
- mem::{self, ManuallyDrop},
- os::raw::c_int,
- ptr,
- sync::Mutex,
- sync::TryLockError,
-};
-
-#[cfg(test)]
-#[macro_use]
-pub mod test;
-pub mod device;
-pub mod export_table;
-pub mod function;
-#[cfg_attr(windows, path = "os_win.rs")]
-#[cfg_attr(not(windows), path = "os_unix.rs")]
-pub(crate) mod os;
-pub(crate) mod module;
-pub(crate) mod context;
-pub(crate) mod memory;
-pub(crate) mod link;
-pub(crate) mod pointer;
+use cuda_types::*;
+use hip_runtime_sys::*;
+use std::mem::{self, ManuallyDrop, MaybeUninit};
+
+pub(super) mod context;
+pub(super) mod device;
+pub(super) mod driver;
+pub(super) mod function;
+pub(super) mod memory;
+pub(super) mod module;
+pub(super) mod pointer;
#[cfg(debug_assertions)]
-pub fn unimplemented() -> CUresult {
+pub(crate) fn unimplemented() -> CUresult {
unimplemented!()
}
#[cfg(not(debug_assertions))]
-pub fn unimplemented() -> CUresult {
- CUresult::CUDA_ERROR_NOT_SUPPORTED
+pub(crate) fn unimplemented() -> CUresult {
+ CUresult::ERROR_NOT_SUPPORTED
}
-#[macro_export]
-macro_rules! hip_call {
- ($expr:expr) => {
- #[allow(unused_unsafe)]
- {
- let err = unsafe { $expr };
- if err != hip_runtime_sys::hipError_t::hipSuccess {
- return Result::Err(err);
+pub(crate) trait FromCuda<'a, T>: Sized {
+ fn from_cuda(t: &'a T) -> Result<Self, CUerror>;
+}
+
+macro_rules! from_cuda_nop {
+ ($($type_:ty),*) => {
+ $(
+ impl<'a> FromCuda<'a, $type_> for $type_ {
+ fn from_cuda(x: &'a $type_) -> Result<Self, CUerror> {
+ Ok(*x)
+ }
}
- }
+
+ impl<'a> FromCuda<'a, *mut $type_> for &'a mut $type_ {
+ fn from_cuda(x: &'a *mut $type_) -> Result<Self, CUerror> {
+ match unsafe { x.as_mut() } {
+ Some(x) => Ok(x),
+ None => Err(CUerror::INVALID_VALUE),
+ }
+ }
+ }
+ )*
+ };
+}
+
+macro_rules! from_cuda_transmute {
+ ($($from:ty => $to:ty),*) => {
+ $(
+ impl<'a> FromCuda<'a, $from> for $to {
+ fn from_cuda(x: &'a $from) -> Result<Self, CUerror> {
+ Ok(unsafe { std::mem::transmute(*x) })
+ }
+ }
+
+ impl<'a> FromCuda<'a, *mut $from> for &'a mut $to {
+ fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> {
+ match unsafe { x.cast::<$to>().as_mut() } {
+ Some(x) => Ok(x),
+ None => Err(CUerror::INVALID_VALUE),
+ }
+ }
+ }
+
+ impl<'a> FromCuda<'a, *mut $from> for * mut $to {
+ fn from_cuda(x: &'a *mut $from) -> Result<Self, CUerror> {
+ Ok(x.cast::<$to>())
+ }
+ }
+ )*
+ };
+}
+
+macro_rules! from_cuda_object {
+ ($($type_:ty),*) => {
+ $(
+ impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for <$type_ as ZludaObject>::CudaHandle {
+ fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<<$type_ as ZludaObject>::CudaHandle, CUerror> {
+ Ok(*handle)
+ }
+ }
+
+ impl<'a> FromCuda<'a, *mut <$type_ as ZludaObject>::CudaHandle> for &'a mut <$type_ as ZludaObject>::CudaHandle {
+ fn from_cuda(handle: &'a *mut <$type_ as ZludaObject>::CudaHandle) -> Result<&'a mut <$type_ as ZludaObject>::CudaHandle, CUerror> {
+ match unsafe { handle.as_mut() } {
+ Some(x) => Ok(x),
+ None => Err(CUerror::INVALID_VALUE),
+ }
+ }
+ }
+
+ impl<'a> FromCuda<'a, <$type_ as ZludaObject>::CudaHandle> for &'a $type_ {
+ fn from_cuda(handle: &'a <$type_ as ZludaObject>::CudaHandle) -> Result<&'a $type_, CUerror> {
+ Ok(as_ref(handle).as_result()?)
+ }
+ }
+ )*
};
}
-pub trait HasLivenessCookie: Sized {
+from_cuda_nop!(
+ *mut i8,
+ *mut i32,
+ *mut usize,
+ *const ::core::ffi::c_void,
+ *const ::core::ffi::c_char,
+ *mut ::core::ffi::c_void,
+ *mut *mut ::core::ffi::c_void,
+ i32,
+ u32,
+ usize,
+ cuda_types::CUdevprop,
+ CUdevice_attribute
+);
+from_cuda_transmute!(
+ CUuuid => hipUUID,
+ CUfunction => hipFunction_t,
+ CUfunction_attribute => hipFunction_attribute,
+ CUstream => hipStream_t,
+ CUpointer_attribute => hipPointer_attribute,
+ CUdeviceptr_v2 => hipDeviceptr_t
+);
+from_cuda_object!(module::Module, context::Context);
+
+impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
+ fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
+ Ok(match *limit {
+ CUlimit::CU_LIMIT_STACK_SIZE => hipLimit_t::hipLimitStackSize,
+ CUlimit::CU_LIMIT_PRINTF_FIFO_SIZE => hipLimit_t::hipLimitPrintfFifoSize,
+ CUlimit::CU_LIMIT_MALLOC_HEAP_SIZE => hipLimit_t::hipLimitMallocHeapSize,
+ _ => return Err(CUerror::NOT_SUPPORTED),
+ })
+ }
+}
+
+pub(crate) trait ZludaObject: Sized + Send + Sync {
const COOKIE: usize;
- const LIVENESS_FAIL: CUresult;
+ const LIVENESS_FAIL: CUerror = cuda_types::CUerror::INVALID_VALUE;
- fn try_drop(&mut self) -> Result<(), CUresult>;
+ type CudaHandle: Sized;
+
+ fn drop_checked(&mut self) -> CUresult;
+
+ fn wrap(self) -> Self::CudaHandle {
+ unsafe { mem::transmute_copy(&LiveCheck::wrap(self)) }
+ }
}
-// This struct is a best-effort check if wrapped value has been dropped,
-// while it's inherently safe, its use coming from FFI is very unsafe
#[repr(C)]
-pub struct LiveCheck<T: HasLivenessCookie> {
+pub(crate) struct LiveCheck<T: ZludaObject> {
cookie: usize,
- data: ManuallyDrop<T>,
+ data: MaybeUninit<T>,
}
-impl<T: HasLivenessCookie> LiveCheck<T> {
- pub fn new(data: T) -> Self {
+impl<T: ZludaObject> LiveCheck<T> {
+ fn new(data: T) -> Self {
LiveCheck {
cookie: T::COOKIE,
- data: ManuallyDrop::new(data),
+ data: MaybeUninit::new(data),
}
}
- fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
- let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
- ctx_box.try_drop()?;
- unsafe { ManuallyDrop::drop(&mut ctx_box) };
- Ok(())
+ fn as_handle(&self) -> T::CudaHandle {
+ unsafe { mem::transmute_copy(&self) }
}
- unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
- let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
- outer_ptr as *mut Self
+ fn wrap(data: T) -> *mut Self {
+ Box::into_raw(Box::new(Self::new(data)))
}
- pub unsafe fn as_ref_unchecked(&self) -> &T {
- &self.data
- }
-
- pub fn as_option_mut(&mut self) -> Option<&mut T> {
+ fn as_result(&self) -> Result<&T, CUerror> {
if self.cookie == T::COOKIE {
- Some(&mut self.data)
- } else {
- None
- }
- }
-
- pub fn as_result(&self) -> Result<&T, CUresult> {
- if self.cookie == T::COOKIE {
- Ok(&self.data)
- } else {
- Err(T::LIVENESS_FAIL)
- }
- }
-
- pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
- if self.cookie == T::COOKIE {
- Ok(&mut self.data)
+ Ok(unsafe { self.data.assume_init_ref() })
} else {
Err(T::LIVENESS_FAIL)
}
}
+ // This looks like nonsense, but it's not. There are two cases:
+ // Err(CUerror) -> meaning that the object is invalid, this pointer does not point into valid memory
+ // Ok(maybe_error) -> meaning that the object is valid, we dropped everything, but there *might*
+ // be an error in the underlying runtime that we want to propagate
#[must_use]
- pub fn try_drop(&mut self) -> Result<(), CUresult> {
+ fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> {
if self.cookie == T::COOKIE {
self.cookie = 0;
- self.data.try_drop()?;
- unsafe { ManuallyDrop::drop(&mut self.data) };
- return Ok(());
- }
- Err(T::LIVENESS_FAIL)
- }
-}
-
-impl<T: HasLivenessCookie> Drop for LiveCheck<T> {
- fn drop(&mut self) {
- self.cookie = 0;
- }
-}
-
-pub trait CudaRepr: Sized {
- type Impl: Sized;
-}
-
-impl<T: CudaRepr> CudaRepr for *mut T {
- type Impl = *mut T::Impl;
-}
-
-pub trait Decuda<To> {
- fn decuda(self: Self) -> To;
-}
-
-impl<T: CudaRepr> Decuda<*mut T::Impl> for *mut T {
- fn decuda(self: Self) -> *mut T::Impl {
- self as *mut _
- }
-}
-
-impl<T> From<TryLockError<T>> for CUresult {
- fn from(_: TryLockError<T>) -> Self {
- CUresult::CUDA_ERROR_ILLEGAL_STATE
- }
-}
-
-impl From<ocl_core::Error> for CUresult {
- fn from(result: ocl_core::Error) -> Self {
- match result {
- _ => CUresult::CUDA_ERROR_UNKNOWN,
- }
- }
-}
-
-impl From<hip_runtime_sys::hipError_t> for CUresult {
- fn from(result: hip_runtime_sys::hipError_t) -> Self {
- match result {
- hip_runtime_sys::hipError_t::hipErrorRuntimeMemory
- | hip_runtime_sys::hipError_t::hipErrorRuntimeOther => CUresult::CUDA_ERROR_UNKNOWN,
- hip_runtime_sys::hipError_t(e) => CUresult(e),
- }
- }
-}
-
-pub trait Encuda {
- type To: Sized;
- fn encuda(self: Self) -> Self::To;
-}
-
-impl Encuda for CUresult {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- self
- }
-}
-
-impl Encuda for () {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- CUresult::CUDA_SUCCESS
- }
-}
-
-impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1, T2> {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- match self {
- Ok(e) => e.encuda(),
- Err(e) => e.encuda(),
+ let result = unsafe { self.data.assume_init_mut().drop_checked() };
+ unsafe { MaybeUninit::assume_init_drop(&mut self.data) };
+ Ok(result)
+ } else {
+ Err(T::LIVENESS_FAIL)
}
}
}
-impl Encuda for hipError_t {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- self.into()
- }
-}
-
-unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
- mem::transmute(t)
-}
-
-unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
- mem::transmute(t)
+pub fn as_ref<'a, T: ZludaObject>(
+ handle: &'a T::CudaHandle,
+) -> &'a ManuallyDrop<Box<LiveCheck<T>>> {
+ unsafe { mem::transmute(handle) }
}
-pub fn driver_get_version() -> c_int {
- i32::max_value()
-}
-
-impl<'a> CudaRepr for CUdeviceptr {
- type Impl = *mut c_void;
-}
-
-impl Decuda<*mut c_void> for CUdeviceptr {
- fn decuda(self) -> *mut c_void {
- self.0 as *mut _
- }
+pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror> {
+ let mut wrapped_object: ManuallyDrop<Box<LiveCheck<T>>> =
+ unsafe { mem::transmute_copy(&handle) };
+ let underlying_error = LiveCheck::drop_checked(&mut wrapped_object)?;
+ unsafe { ManuallyDrop::drop(&mut wrapped_object) };
+ underlying_error
}