diff options
Diffstat (limited to 'zluda/src/impl/context.rs')
-rw-r--r-- | zluda/src/impl/context.rs | 374 |
1 files changed, 0 insertions, 374 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs deleted file mode 100644 index ed3f90c..0000000 --- a/zluda/src/impl/context.rs +++ /dev/null @@ -1,374 +0,0 @@ -use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck}; -use super::{transmute_lifetime_mut, CUresult, GlobalState}; -use crate::{cuda::CUcontext, cuda_impl}; -use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32}; -use std::{ - collections::HashSet, - mem::{self}, -}; - -thread_local! { - pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new()); -} - -pub type Context = LiveCheck<ContextData>; - -impl HasLivenessCookie for ContextData { - #[cfg(target_pointer_width = "64")] - const COOKIE: usize = 0x5f0119560b643ffb; - - #[cfg(target_pointer_width = "32")] - const COOKIE: usize = 0x0b643ffb; - - const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; - - fn try_drop(&mut self) -> Result<(), CUresult> { - for stream in self.streams.iter() { - let stream = unsafe { &mut **stream }; - stream.context = ptr::null_mut(); - Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?; - } - Ok(()) - } -} - -enum ContextRefCount { - Primary, - NonPrimary(NonZeroU32), -} - -impl ContextRefCount { - fn new(is_primary: bool) -> Self { - if is_primary { - ContextRefCount::Primary - } else { - ContextRefCount::NonPrimary(unsafe { NonZeroU32::new_unchecked(1) }) - } - } - - fn incr(&mut self) -> Result<(), CUresult> { - match self { - ContextRefCount::Primary => Ok(()), - ContextRefCount::NonPrimary(c) => { - let (new_count, overflow) = c.get().overflowing_add(1); - if overflow { - Err(CUresult::CUDA_ERROR_INVALID_VALUE) - } else { - *c = unsafe { NonZeroU32::new_unchecked(new_count) }; - Ok(()) - } - } - } - } - - #[must_use] - fn decr(&mut self) -> bool { - match self { - ContextRefCount::Primary => false, - ContextRefCount::NonPrimary(c) => { - if c.get() == 1 { - return true; - } - *c = unsafe { NonZeroU32::new_unchecked(c.get() - 1) }; - false - } - } - } -} - -pub struct ContextData { - pub flags: AtomicU32, - // This pointer is null only for a moment when constructing primary context - pub device: *mut device::Device, - ref_count: ContextRefCount, - pub default_stream: StreamData, - pub streams: HashSet<*mut StreamData>, - // All the fields below are here to support internal CUDA driver API - pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, - pub cuda_state: *mut cuda_impl::rt::ContextState, - pub cuda_dtor_cb: Option< - extern "system" fn( - CUcontext, - *mut cuda_impl::rt::ContextStateManager, - *mut cuda_impl::rt::ContextState, - ), - >, -} - -impl ContextData { - pub fn new( - flags: c_uint, - is_primary: bool, - dev: *mut device::Device, - ) -> Result<Self, CUresult> { - let default_stream = StreamData::new_unitialized()?; - Ok(ContextData { - flags: AtomicU32::new(flags), - device: dev, - ref_count: ContextRefCount::new(is_primary), - default_stream, - streams: HashSet::new(), - cuda_manager: ptr::null_mut(), - cuda_state: ptr::null_mut(), - cuda_dtor_cb: None, - }) - } -} - -impl Context { - pub fn late_init(&mut self) { - let ctx_data: &'static mut _ = { - let this = self.as_option_mut().unwrap(); - let result = { unsafe { transmute_lifetime_mut(this) } }; - drop(this); - result - }; - { self.as_option_mut().unwrap() } - .default_stream - .late_init(ctx_data); - } -} - -pub fn create_v2( - pctx: *mut *mut Context, - flags: u32, - dev_idx: device::Index, -) -> Result<(), CUresult> { - if pctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { - let dev_ptr = dev as *mut _; - let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( - flags, - false, - dev_ptr as *mut _, - )?)); - ctx_box.late_init(); - Ok::<_, CUresult>(ctx_box) - })??; - let ctx_ref = ctx_box.as_mut() as *mut Context; - unsafe { *pctx = ctx_ref }; - mem::forget(ctx_box); - CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); - Ok(()) -} - -pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> { - if ctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - CONTEXT_STACK.with(|stack| { - let mut stack = stack.borrow_mut(); - let should_pop = match stack.last() { - Some(active_ctx) => *active_ctx == (ctx as *mut _), - None => false, - }; - if should_pop { - stack.pop(); - } - }); - GlobalState::lock(|_| Context::destroy_impl(ctx))? -} - -pub(crate) fn push_current_v2(pctx: *mut Context) -> CUresult { - if pctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; - } - CONTEXT_STACK.with(|stack| stack.borrow_mut().push(pctx)); - CUresult::CUDA_SUCCESS -} - -pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { - if pctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; - } - let mut ctx = CONTEXT_STACK.with(|stack| stack.borrow_mut().pop()); - let ctx_ptr = match &mut ctx { - Some(ctx) => *ctx as *mut _, - None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, - }; - unsafe { *pctx = ctx_ptr }; - CUresult::CUDA_SUCCESS -} - -pub fn get_current(pctx: *mut *mut Context) -> Result<(), CUresult> { - if pctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - let ctx = CONTEXT_STACK.with(|stack| match stack.borrow().last() { - Some(ctx) => *ctx as *mut _, - None => ptr::null_mut(), - }); - unsafe { *pctx = ctx }; - Ok(()) -} - -pub fn set_current(ctx: *mut Context) -> CUresult { - if ctx == ptr::null_mut() { - CONTEXT_STACK.with(|stack| stack.borrow_mut().pop()); - CUresult::CUDA_SUCCESS - } else { - CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx)); - CUresult::CUDA_SUCCESS - } -} - -pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { - if ctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - GlobalState::lock(|_| { - unsafe { &*ctx }.as_result()?; - Ok::<_, CUresult>(()) - })??; - //TODO: query device for properties roughly matching CUDA API version - unsafe { *version = 1100 }; - Ok(()) -} - -pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> { - let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?; - unsafe { *dev = dev_idx }; - Ok(()) -} - -pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> { - if pctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| { - let ctx = unchecked_ctx.as_result_mut()?; - ctx.ref_count.incr()?; - Ok::<_, CUresult>(unchecked_ctx as *mut _) - })??; - unsafe { *pctx = ctx }; - Ok(()) -} - -pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { - if pctx == ptr::null_mut() { - return Err(CUresult::CUDA_ERROR_INVALID_VALUE); - } - GlobalState::lock_current_context_unchecked(|unchecked_ctx| { - let ctx = unchecked_ctx.as_result_mut()?; - if ctx.ref_count.decr() { - Context::destroy_impl(unchecked_ctx)?; - } - Ok::<_, CUresult>(()) - })? -} - -pub(crate) fn synchronize() -> Result<(), CUresult> { - GlobalState::lock_current_context(|ctx| { - ctx.default_stream.synchronize()?; - for stream in ctx.streams.iter().copied() { - unsafe { &mut *stream }.synchronize()?; - } - Ok(()) - })? -} - -#[cfg(test)] -mod test { - use super::super::test::CudaDriverFns; - use super::super::CUresult; - use std::{ffi::c_void, ptr}; - - cuda_driver_test!(destroy_leaves_zombie_context); - - fn destroy_leaves_zombie_context<T: CudaDriverFns>() { - assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); - let mut ctx1 = ptr::null_mut(); - let mut ctx2 = ptr::null_mut(); - let mut ctx3 = ptr::null_mut(); - assert_eq!(T::cuCtxCreate_v2(&mut ctx1, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxCreate_v2(&mut ctx3, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); - let mut popped_ctx1 = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut popped_ctx1), - CUresult::CUDA_SUCCESS - ); - assert_eq!(popped_ctx1, ctx3); - let mut popped_ctx2 = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut popped_ctx2), - CUresult::CUDA_SUCCESS - ); - assert_eq!(popped_ctx2, ctx2); - let mut popped_ctx3 = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut popped_ctx3), - CUresult::CUDA_SUCCESS - ); - assert_eq!(popped_ctx3, ctx1); - let mut temp = 0; - assert_eq!( - T::cuCtxGetApiVersion(ctx2, &mut temp), - CUresult::CUDA_ERROR_INVALID_CONTEXT - ); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut ptr::null_mut()), - CUresult::CUDA_ERROR_INVALID_CONTEXT - ); - } - - cuda_driver_test!(empty_pop_fails); - - fn empty_pop_fails<T: CudaDriverFns>() { - assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); - let mut ctx = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut ctx), - CUresult::CUDA_ERROR_INVALID_CONTEXT - ); - } - - cuda_driver_test!(destroy_pops_top_of_stack); - - fn destroy_pops_top_of_stack<T: CudaDriverFns>() { - assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); - let mut ctx1 = ptr::null_mut(); - let mut ctx2 = ptr::null_mut(); - assert_eq!(T::cuCtxCreate_v2(&mut ctx1, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); - let mut popped_ctx1 = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut popped_ctx1), - CUresult::CUDA_SUCCESS - ); - assert_eq!(popped_ctx1, ctx1); - let mut popped_ctx2 = ptr::null_mut(); - assert_eq!( - T::cuCtxPopCurrent_v2(&mut popped_ctx2), - CUresult::CUDA_ERROR_INVALID_CONTEXT - ); - } - - cuda_driver_test!(double_destroy_fails); - - fn double_destroy_fails<T: CudaDriverFns>() { - assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); - let mut ctx = ptr::null_mut(); - assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); - assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); - let destroy_result = T::cuCtxDestroy_v2(ctx); - // original CUDA impl returns randomly one or the other - assert!( - destroy_result == CUresult::CUDA_ERROR_INVALID_CONTEXT - || destroy_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED - ); - } - - cuda_driver_test!(no_current_on_init); - - fn no_current_on_init<T: CudaDriverFns>() { - assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); - let mut ctx = 1 as *mut c_void; - assert_eq!(T::cuCtxGetCurrent(&mut ctx), CUresult::CUDA_SUCCESS); - assert_eq!(ctx, ptr::null_mut()); - } -} |