diff options
Diffstat (limited to 'notcuda/src/impl/mod.rs')
-rw-r--r-- | notcuda/src/impl/mod.rs | 161 |
1 files changed, 123 insertions, 38 deletions
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 5a72ce4..770a32b 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,5 +1,15 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; -use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; +use crate::{ + cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}, + r#impl::device::Device, +}; +use std::{ + ffi::c_void, + mem::{self, ManuallyDrop}, + os::raw::c_int, + ptr, + sync::Mutex, + sync::TryLockError, +}; #[cfg(test)] #[macro_use] @@ -7,9 +17,9 @@ pub mod test; pub mod context; pub mod device; pub mod export_table; +pub mod function; pub mod memory; pub mod module; -pub mod function; pub mod stream; #[cfg(debug_assertions)] @@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult { CUresult::CUDA_ERROR_NOT_SUPPORTED } -pub trait HasLivenessCookie { +pub trait HasLivenessCookie: Sized { const COOKIE: usize; + const LIVENESS_FAIL: CUresult; + + fn try_drop(&mut self) -> Result<(), CUresult>; } // This struct is a best-effort check if wrapped value has been dropped, @@ -42,34 +55,55 @@ impl<T: HasLivenessCookie> LiveCheck<T> { } } + fn destroy_impl(this: *mut Self) -> Result<(), CUresult> { + let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) }); + ctx_box.try_drop()?; + unsafe { ManuallyDrop::drop(&mut ctx_box) }; + Ok(()) + } + + unsafe fn ptr_from_inner(this: *mut T) -> *mut Self { + let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>()); + outer_ptr as *mut Self + } + pub unsafe fn as_ref_unchecked(&self) -> &T { &self.data } - pub fn as_ref(&self) -> Option<&T> { + pub fn as_option_mut(&mut self) -> Option<&mut T> { if self.cookie == T::COOKIE { - Some(&self.data) + Some(&mut self.data) } else { None } } - pub fn as_mut(&mut self) -> Option<&mut T> { + pub fn as_result(&self) -> Result<&T, CUresult> { if self.cookie == T::COOKIE { - Some(&mut self.data) + Ok(&self.data) } else { - None + Err(T::LIVENESS_FAIL) + } + } + + pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> { + if self.cookie == T::COOKIE { + Ok(&mut self.data) + } else { + Err(T::LIVENESS_FAIL) } } #[must_use] - pub fn try_drop(&mut self) -> bool { + pub fn try_drop(&mut self) -> Result<(), CUresult> { if self.cookie == T::COOKIE { self.cookie = 0; + self.data.try_drop()?; unsafe { ManuallyDrop::drop(&mut self.data) }; - return true; + return Ok(()); } - false + Err(T::LIVENESS_FAIL) } } @@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult { } } +impl<T> From<TryLockError<T>> for CUresult { + fn from(_: TryLockError<T>) -> Self { + CUresult::CUDA_ERROR_ILLEGAL_STATE + } +} + pub trait Encuda { type To: Sized; fn encuda(self: Self) -> Self::To; @@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1, } } -pub enum Error { - L0(l0::sys::ze_result_t), - Cuda(CUresult), -} - -impl Encuda for Error { - type To = CUresult; - fn encuda(self: Self) -> Self::To { - match self { - Error::L0(e) => e.into(), - Error::Cuda(e) => e, - } - } -} - lazy_static! { static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None); } struct GlobalState { - driver: l0::Driver, + devices: Vec<Device>, } unsafe impl Send for GlobalState {} +impl GlobalState { + fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> { + let mut mutex = GLOBAL_STATE + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?; + Ok(f(global_state)) + } + + fn lock_device<T>( + device::Index(dev_idx): device::Index, + f: impl FnOnce(&'static mut device::Device) -> T, + ) -> Result<T, CUresult> { + if dev_idx < 0 { + return Err(CUresult::CUDA_ERROR_INVALID_DEVICE); + } + Self::lock(|global_state| { + if dev_idx >= global_state.devices.len() as c_int { + Err(CUresult::CUDA_ERROR_INVALID_DEVICE) + } else { + Ok(f(unsafe { + transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize]) + })) + } + })? + } + + fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>( + f: F, + ) -> Result<R, CUresult> { + Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))? + } + + fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>( + f: F, + ) -> Result<R, CUresult> { + context::CONTEXT_STACK.with(|stack| { + stack + .borrow_mut() + .last_mut() + .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) + .map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))? + }) + } + + fn lock_stream<T>( + stream: *mut stream::Stream, + f: impl FnOnce(&mut stream::StreamData) -> T, + ) -> Result<T, CUresult> { + if stream == ptr::null_mut() + || stream == stream::CU_STREAM_LEGACY + || stream == stream::CU_STREAM_PER_THREAD + { + Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))? + } else { + Self::lock(|_| { + let stream = unsafe { &mut *stream }.as_result_mut()?; + Ok(f(stream)) + })? + } + } +} + // TODO: implement fn is_intel_gpu_driver(_: &l0::Driver) -> bool { true } -pub fn init() -> l0::Result<()> { +pub fn init() -> Result<(), CUresult> { let mut global_state = GLOBAL_STATE .lock() - .map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?; + .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; if global_state.is_some() { return Ok(()); } l0::init()?; let drivers = l0::Driver::get()?; - let driver = match drivers.into_iter().find(is_intel_gpu_driver) { - None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), - Some(driver) => { - device::init(&driver)?; - driver - } + let devices = match drivers.into_iter().find(is_intel_gpu_driver) { + None => return Err(CUresult::CUDA_ERROR_UNKNOWN), + Some(driver) => device::init(&driver)?, }; - *global_state = Some(GlobalState { driver }); + *global_state = Some(GlobalState { devices }); drop(global_state); Ok(()) } -unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { +unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T { mem::transmute(t) } |