summary refs log tree commit diff homepage
path: root/notcuda/src/impl/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda/src/impl/mod.rs')
-rw-r--r-- notcuda/src/impl/mod.rs 161
1 files changed, 123 insertions, 38 deletions
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs
index 5a72ce4..770a32b 100644
--- a/notcuda/src/impl/mod.rs
+++ b/notcuda/src/impl/mod.rs
@@ -1,5 +1,15 @@
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
-use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
+use crate::{
+ cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
+ r#impl::device::Device,
+};
+use std::{
+ ffi::c_void,
+ mem::{self, ManuallyDrop},
+ os::raw::c_int,
+ ptr,
+ sync::Mutex,
+ sync::TryLockError,
+};
#[cfg(test)]
#[macro_use]
@@ -7,9 +17,9 @@ pub mod test;
pub mod context;
pub mod device;
pub mod export_table;
+pub mod function;
pub mod memory;
pub mod module;
-pub mod function;
pub mod stream;
#[cfg(debug_assertions)]
@@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult {
CUresult::CUDA_ERROR_NOT_SUPPORTED
}
-pub trait HasLivenessCookie {
+pub trait HasLivenessCookie: Sized {
const COOKIE: usize;
+ const LIVENESS_FAIL: CUresult;
+
+ fn try_drop(&mut self) -> Result<(), CUresult>;
}
// This struct is a best-effort check if wrapped value has been dropped,
@@ -42,34 +55,55 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
}
}
+    /// Tears down the wrapper behind `this` and releases its heap allocation.
+    ///
+    /// The box is rebuilt from the raw pointer but parked inside `ManuallyDrop`,
+    /// so that a failing `try_drop` returns its error without freeing the
+    /// allocation. The box itself is dropped only after `try_drop` succeeded.
+    fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
+        // NOTE(review): `this` is assumed to originate from `Box::into_raw` and
+        // to be non-null; nothing in this function verifies that -- confirm at
+        // the call sites.
+        let mut boxed = ManuallyDrop::new(unsafe { Box::from_raw(this) });
+        match boxed.try_drop() {
+            Ok(()) => {
+                // The payload has been torn down; now release the allocation.
+                unsafe { ManuallyDrop::drop(&mut boxed) };
+                Ok(())
+            }
+            Err(err) => Err(err),
+        }
+    }
+
+    /// Computes a pointer to the enclosing wrapper from a pointer to its payload.
+    ///
+    /// The result is `this` moved back by `size_of::<usize>()` bytes.
+    ///
+    /// # Safety
+    /// The backwards offset must stay inside the allocation `this` points into.
+    /// NOTE(review): presumably relies on a `usize` cookie being laid out
+    /// immediately before the payload with no padding in between -- confirm
+    /// against the struct definition.
+    unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
+        let bytes_back = mem::size_of::<usize>();
+        (this as *mut u8).sub(bytes_back) as *mut Self
+    }
+
/// Returns a reference to the wrapped value without comparing the cookie.
///
/// # Safety
/// Unlike the checked accessors, this performs no liveness test. `try_drop`
/// drops `self.data` in place, so the caller must make sure that has not
/// happened before reading through the returned reference.
pub unsafe fn as_ref_unchecked(&self) -> &T {
&self.data
}
- pub fn as_ref(&self) -> Option<&T> {
+ pub fn as_option_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE {
- Some(&self.data)
+ Some(&mut self.data)
} else {
None
}
}
- pub fn as_mut(&mut self) -> Option<&mut T> {
+ pub fn as_result(&self) -> Result<&T, CUresult> {
if self.cookie == T::COOKIE {
- Some(&mut self.data)
+ Ok(&self.data)
} else {
- None
+ Err(T::LIVENESS_FAIL)
+ }
+ }
+
+ pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
+ if self.cookie == T::COOKIE {
+ Ok(&mut self.data)
+ } else {
+ Err(T::LIVENESS_FAIL)
}
}
#[must_use]
- pub fn try_drop(&mut self) -> bool {
+ pub fn try_drop(&mut self) -> Result<(), CUresult> {
if self.cookie == T::COOKIE {
self.cookie = 0;
+ self.data.try_drop()?;
unsafe { ManuallyDrop::drop(&mut self.data) };
- return true;
+ return Ok(());
}
- false
+ Err(T::LIVENESS_FAIL)
}
}
@@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult {
}
}
+/// Converts a failed `try_lock` into a CUDA status code, so such failures can
+/// be propagated with `?` from functions returning `Result<_, CUresult>`.
+///
+/// The error value is discarded: every `TryLockError` is reported as
+/// `CUDA_ERROR_ILLEGAL_STATE`.
+impl<T> From<TryLockError<T>> for CUresult {
+    fn from(_lock_err: TryLockError<T>) -> CUresult {
+        CUresult::CUDA_ERROR_ILLEGAL_STATE
+    }
+}
+
pub trait Encuda {
type To: Sized;
fn encuda(self: Self) -> Self::To;
@@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
}
}
-pub enum Error {
- L0(l0::sys::ze_result_t),
- Cuda(CUresult),
-}
-
-impl Encuda for Error {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- match self {
- Error::L0(e) => e.into(),
- Error::Cuda(e) => e,
- }
- }
-}
-
lazy_static! {
static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None);
}
struct GlobalState {
- driver: l0::Driver,
+ devices: Vec<Device>,
}
unsafe impl Send for GlobalState {}
+/// Helpers that run a closure against driver-wide objects while the
+/// `GLOBAL_STATE` mutex is held.
+impl GlobalState {
+    /// Runs `f` on the global state with the `GLOBAL_STATE` mutex held.
+    ///
+    /// A poisoned mutex is recovered (its inner value is used) rather than
+    /// reported as an error.
+    ///
+    /// # Errors
+    /// `CUDA_ERROR_ILLEGAL_STATE` if the global state is still `None`.
+    fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> {
+        let mut guard = GLOBAL_STATE
+            .lock()
+            .unwrap_or_else(|poison| poison.into_inner());
+        let global_state = guard.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?;
+        Ok(f(global_state))
+    }
+
+    /// Runs `f` on the device with index `dev_idx` while the global lock is held.
+    ///
+    /// # Errors
+    /// * `CUDA_ERROR_INVALID_DEVICE` if the index is negative (checked before
+    ///   the lock is taken) or not below the number of known devices.
+    /// * `CUDA_ERROR_ILLEGAL_STATE` if the global state is still `None`.
+    fn lock_device<T>(
+        device::Index(dev_idx): device::Index,
+        f: impl FnOnce(&'static mut device::Device) -> T,
+    ) -> Result<T, CUresult> {
+        if dev_idx < 0 {
+            return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
+        }
+        // Non-negative by the guard above, so widening to `usize` is lossless.
+        let idx = dev_idx as usize;
+        Self::lock(|global_state| {
+            // Checked lookup: no truncating cast of `len()` and no panicking index.
+            let device = global_state
+                .devices
+                .get_mut(idx)
+                .ok_or(CUresult::CUDA_ERROR_INVALID_DEVICE)?;
+            // NOTE(review): the borrow is stretched to `'static` for `f`; the
+            // compiler no longer stops `f` from keeping the reference past the
+            // lock -- callers must not retain it.
+            Ok(f(unsafe { transmute_lifetime_mut(device) }))
+        })?
+    }
+
+    /// Runs `f` on the data of the context on top of this thread's context
+    /// stack, after the context passed its liveness check.
+    ///
+    /// # Errors
+    /// Anything `lock_current_context_unchecked` reports, plus the error
+    /// returned by `as_result_mut` when the liveness check fails.
+    fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>(
+        f: F,
+    ) -> Result<R, CUresult> {
+        Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))?
+    }
+
+    /// Runs `f` on the topmost entry of `context::CONTEXT_STACK` while the
+    /// global lock is held, without checking that the context is still live.
+    ///
+    /// # Errors
+    /// * `CUDA_ERROR_INVALID_CONTEXT` if the stack is empty.
+    /// * `CUDA_ERROR_ILLEGAL_STATE` if the global state is still `None`.
+    fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>(
+        f: F,
+    ) -> Result<R, CUresult> {
+        context::CONTEXT_STACK.with(|stack| {
+            stack
+                .borrow_mut()
+                .last_mut()
+                .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
+                // The stack entry is dereferenced as a raw pointer only once
+                // the global lock has been acquired.
+                .map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))?
+        })
+    }
+
+    /// Runs `f` on the data of `stream`.
+    ///
+    /// A null pointer, `CU_STREAM_LEGACY` and `CU_STREAM_PER_THREAD` all
+    /// resolve to the current context's `default_stream`. Any other pointer is
+    /// dereferenced under the global lock and must pass its liveness check.
+    ///
+    /// # Errors
+    /// Those of `lock_current_context` for the default-stream case; otherwise
+    /// those of `lock` plus the error returned by `as_result_mut`.
+    fn lock_stream<T>(
+        stream: *mut stream::Stream,
+        f: impl FnOnce(&mut stream::StreamData) -> T,
+    ) -> Result<T, CUresult> {
+        if stream == ptr::null_mut()
+            || stream == stream::CU_STREAM_LEGACY
+            || stream == stream::CU_STREAM_PER_THREAD
+        {
+            Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))?
+        } else {
+            Self::lock(|_| {
+                let stream = unsafe { &mut *stream }.as_result_mut()?;
+                Ok(f(stream))
+            })?
+        }
+    }
+}
+
// TODO: implement
// Placeholder predicate used by `init` to pick a driver: it ignores its
// argument and accepts every driver, so the first driver found is selected.
fn is_intel_gpu_driver(_: &l0::Driver) -> bool {
true
}
-pub fn init() -> l0::Result<()> {
+pub fn init() -> Result<(), CUresult> {
let mut global_state = GLOBAL_STATE
.lock()
- .map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?;
+ .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
if global_state.is_some() {
return Ok(());
}
l0::init()?;
let drivers = l0::Driver::get()?;
- let driver = match drivers.into_iter().find(is_intel_gpu_driver) {
- None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN),
- Some(driver) => {
- device::init(&driver)?;
- driver
- }
+ let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
+ None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
+ Some(driver) => device::init(&driver)?,
};
- *global_state = Some(GlobalState { driver });
+ *global_state = Some(GlobalState { devices });
drop(global_state);
Ok(())
}
-unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
+unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t)
}