diff options
Diffstat (limited to 'notcuda/src/impl/context.rs')
-rw-r--r-- | notcuda/src/impl/context.rs | 191 |
1 files changed, 103 insertions, 88 deletions
diff --git a/notcuda/src/impl/context.rs b/notcuda/src/impl/context.rs index 91d4460..9689ecf 100644 --- a/notcuda/src/impl/context.rs +++ b/notcuda/src/impl/context.rs @@ -1,18 +1,15 @@ -use super::CUresult; -use super::{device, HasLivenessCookie, LiveCheck}; +use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck}; +use super::{CUresult, GlobalState}; use crate::{cuda::CUcontext, cuda_impl}; use l0::sys::ze_result_t; -use std::mem::{self, ManuallyDrop}; +use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32}; use std::{ - cell::RefCell, - num::NonZeroU32, - os::raw::c_uint, - ptr, - sync::{atomic::AtomicU32, Mutex}, + collections::HashSet, + mem::{self}, }; thread_local! { - pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new()); + pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new()); } pub type Context = LiveCheck<ContextData>; @@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData { #[cfg(target_pointer_width = "32")] const COOKIE: usize = 0x0b643ffb; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; + + fn try_drop(&mut self) -> Result<(), CUresult> { + for stream in self.streams.iter() { + let stream = unsafe { &mut **stream }; + stream.context = ptr::null_mut(); + Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?; + } + Ok(()) + } } enum ContextRefCount { @@ -67,26 +75,16 @@ impl ContextRefCount { } } } - - fn is_primary(&self) -> bool { - match self { - ContextRefCount::Primary => true, - ContextRefCount::NonPrimary(_) => false, - } - } } pub struct ContextData { pub flags: AtomicU32, - pub device_index: device::Index, // This pointer is null only for a moment when constructing primary context - pub device: *const Mutex<device::Device>, - // The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState - pub mutable: Mutex<ContextDataMutable>, -} - -pub struct ContextDataMutable { + pub device: *mut device::Device, ref_count: ContextRefCount, + pub default_stream: StreamData, + pub streams: HashSet<*mut StreamData>, + // All the fields below are here to support internal CUDA driver API pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, pub cuda_state: *mut cuda_impl::rt::ContextState, pub cuda_dtor_cb: Option< @@ -100,63 +98,75 @@ pub struct ContextDataMutable { impl ContextData { pub fn new( + l0_ctx: &mut l0::Context, + l0_dev: &l0::Device, flags: c_uint, is_primary: bool, - dev_index: device::Index, - dev: *const Mutex<device::Device>, - ) -> Self { - ContextData { + dev: *mut device::Device, + ) -> Result<Self, CUresult> { + let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?; + Ok(ContextData { flags: AtomicU32::new(flags), - device_index: dev_index, device: dev, - mutable: Mutex::new(ContextDataMutable { - ref_count: ContextRefCount::new(is_primary), - cuda_manager: ptr::null_mut(), - cuda_state: ptr::null_mut(), - cuda_dtor_cb: None, - }), - } + ref_count: ContextRefCount::new(is_primary), + default_stream, + streams: HashSet::new(), + cuda_manager: ptr::null_mut(), + cuda_state: ptr::null_mut(), + cuda_dtor_cb: None, + }) } } -pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult { +impl Context { + pub fn late_init(&mut self) { + let ctx_data = self.as_option_mut().unwrap(); + ctx_data.default_stream.context = ctx_data as *mut _; + } +} + +pub fn create_v2( + pctx: *mut *mut Context, + flags: u32, + dev_idx: device::Index, +) -> Result<(), CUresult> { if pctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let dev = device::get_device_ref(dev_idx); - let dev = match dev { - Ok(d) => d, - Err(e) => return e, - }; - let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev))); - let ctx_ref = ctx.as_mut() as *mut Context; + let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { + let dev_ptr = dev as *mut _; + let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( + &mut dev.l0_context, + &dev.base, + flags, + false, + dev_ptr as *mut _, + )?)); + ctx_box.late_init(); + Ok::<_, CUresult>(ctx_box) + })??; + let ctx_ref = ctx_box.as_mut() as *mut Context; unsafe { *pctx = ctx_ref }; - mem::forget(ctx); + mem::forget(ctx_box); CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn destroy_v2(ctx: *mut Context) -> CUresult { +pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> { if ctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } CONTEXT_STACK.with(|stack| { let mut stack = stack.borrow_mut(); let should_pop = match stack.last() { - Some(active_ctx) => *active_ctx == (ctx as *const _), + Some(active_ctx) => *active_ctx == (ctx as *mut _), None => false, }; if should_pop { stack.pop(); } }); - let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) }); - if !ctx_box.try_drop() { - CUresult::CUDA_ERROR_INVALID_CONTEXT - } else { - unsafe { ManuallyDrop::drop(&mut ctx_box) }; - CUresult::CUDA_SUCCESS - } + GlobalState::lock(|_| Context::destroy_impl(ctx))? } pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { @@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { CUresult::CUDA_SUCCESS } -pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> { - CONTEXT_STACK.with(|stack| { - stack - .borrow() - .last() - .and_then(|c| unsafe { &**c }.as_ref()) - .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) - .map(f) - }) -} - pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> { if pctx == ptr::null_mut() { return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT); @@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult { } } -pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult { - let _ctx = match unsafe { ctx.as_mut() } { - None => return CUresult::CUDA_ERROR_INVALID_VALUE, - Some(ctx) => match ctx.as_mut() { - None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, - Some(ctx) => ctx, - }, - }; +pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { + if ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| { + unsafe { &*ctx }.as_result()?; + Ok::<_, CUresult>(()) + })??; //TODO: query device for properties roughly matching CUDA API version unsafe { *version = 1100 }; - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn get_device(dev: *mut device::Index) -> CUresult { - let dev_idx = with_current(|ctx| ctx.device_index); - match dev_idx { - Ok(idx) => { - unsafe { *dev = idx } - CUresult::CUDA_SUCCESS - } - Err(err) => err, +pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> { + let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?; + unsafe { *dev = dev_idx }; + Ok(()) +} + +pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } + let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + ctx.ref_count.incr()?; + Ok::<_, CUresult>(unchecked_ctx as *mut _) + })??; + unsafe { *pctx = ctx }; + Ok(()) } -#[cfg(test)] -pub fn is_context_stack_empty() -> bool { - CONTEXT_STACK.with(|stack| stack.borrow().is_empty()) +pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + if ctx.ref_count.decr() { + Context::destroy_impl(unchecked_ctx)?; + } + Ok::<_, CUresult>(()) + })? } #[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; use std::{ffi::c_void, ptr}; |