diff options
Diffstat (limited to 'zluda/src/impl/context.rs')
-rw-r--r-- | zluda/src/impl/context.rs | 359 |
1 files changed, 359 insertions, 0 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs new file mode 100644 index 0000000..873fc47 --- /dev/null +++ b/zluda/src/impl/context.rs @@ -0,0 +1,359 @@ +use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck}; +use super::{CUresult, GlobalState}; +use crate::{cuda::CUcontext, cuda_impl}; +use l0::sys::ze_result_t; +use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32}; +use std::{ + collections::HashSet, + mem::{self}, +}; + +thread_local! { + pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new()); +} + +pub type Context = LiveCheck<ContextData>; + +impl HasLivenessCookie for ContextData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x5f0119560b643ffb; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x0b643ffb; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; + + fn try_drop(&mut self) -> Result<(), CUresult> { + for stream in self.streams.iter() { + let stream = unsafe { &mut **stream }; + stream.context = ptr::null_mut(); + Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?; + } + Ok(()) + } +} + +enum ContextRefCount { + Primary, + NonPrimary(NonZeroU32), +} + +impl ContextRefCount { + fn new(is_primary: bool) -> Self { + if is_primary { + ContextRefCount::Primary + } else { + ContextRefCount::NonPrimary(unsafe { NonZeroU32::new_unchecked(1) }) + } + } + + fn incr(&mut self) -> Result<(), CUresult> { + match self { + ContextRefCount::Primary => Ok(()), + ContextRefCount::NonPrimary(c) => { + let (new_count, overflow) = c.get().overflowing_add(1); + if overflow { + Err(CUresult::CUDA_ERROR_INVALID_VALUE) + } else { + *c = unsafe { NonZeroU32::new_unchecked(new_count) }; + Ok(()) + } + } + } + } + + #[must_use] + fn decr(&mut self) -> bool { + match self { + ContextRefCount::Primary => false, + ContextRefCount::NonPrimary(c) => { + if c.get() == 1 { + return true; + } + *c = unsafe { NonZeroU32::new_unchecked(c.get() - 1) }; + false + } + } + } +} + +pub struct ContextData { + pub flags: AtomicU32, + // This pointer is null only for a moment when constructing primary context + pub device: *mut device::Device, + ref_count: ContextRefCount, + pub default_stream: StreamData, + pub streams: HashSet<*mut StreamData>, + // All the fields below are here to support internal CUDA driver API + pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, + pub cuda_state: *mut cuda_impl::rt::ContextState, + pub cuda_dtor_cb: Option< + extern "C" fn( + CUcontext, + *mut cuda_impl::rt::ContextStateManager, + *mut cuda_impl::rt::ContextState, + ), + >, +} + +impl ContextData { + pub fn new( + l0_ctx: &mut l0::Context, + l0_dev: &l0::Device, + flags: c_uint, + is_primary: bool, + dev: *mut device::Device, + ) -> Result<Self, CUresult> { + let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?; + Ok(ContextData { + flags: AtomicU32::new(flags), + device: dev, + ref_count: ContextRefCount::new(is_primary), + default_stream, + streams: HashSet::new(), + cuda_manager: ptr::null_mut(), + cuda_state: ptr::null_mut(), + cuda_dtor_cb: None, + }) + } +} + +impl Context { + pub fn late_init(&mut self) { + let ctx_data = self.as_option_mut().unwrap(); + ctx_data.default_stream.context = ctx_data as *mut _; + } +} + +pub fn create_v2( + pctx: *mut *mut Context, + flags: u32, + dev_idx: device::Index, +) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { + let dev_ptr = dev as *mut _; + let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( + &mut dev.l0_context, + &dev.base, + flags, + false, + dev_ptr as *mut _, + )?)); + ctx_box.late_init(); + Ok::<_, CUresult>(ctx_box) + })??; + let ctx_ref = ctx_box.as_mut() as *mut Context; + unsafe { *pctx = ctx_ref }; + mem::forget(ctx_box); + CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); + Ok(()) +} + +pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> { + if ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + CONTEXT_STACK.with(|stack| { + let mut stack = stack.borrow_mut(); + let should_pop = match stack.last() { + Some(active_ctx) => *active_ctx == (ctx as *mut _), + None => false, + }; + if should_pop { + stack.pop(); + } + }); + GlobalState::lock(|_| Context::destroy_impl(ctx))? +} + +pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { + if pctx == ptr::null_mut() { + return CUresult::CUDA_ERROR_INVALID_VALUE; + } + let mut ctx = CONTEXT_STACK.with(|stack| stack.borrow_mut().pop()); + let ctx_ptr = match &mut ctx { + Some(ctx) => *ctx as *mut _, + None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, + }; + unsafe { *pctx = ctx_ptr }; + CUresult::CUDA_SUCCESS +} + +pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> { + if pctx == ptr::null_mut() { + return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT); + } + let ctx = CONTEXT_STACK.with(|stack| match stack.borrow().last() { + Some(ctx) => *ctx as *mut _, + None => ptr::null_mut(), + }); + unsafe { *pctx = ctx }; + Ok(()) +} + +pub fn set_current(ctx: *mut Context) -> CUresult { + if ctx == ptr::null_mut() { + CONTEXT_STACK.with(|stack| stack.borrow_mut().pop()); + CUresult::CUDA_SUCCESS + } else { + CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx)); + CUresult::CUDA_SUCCESS + } +} + +pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { + if ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| { + unsafe { &*ctx }.as_result()?; + Ok::<_, CUresult>(()) + })??; + //TODO: query device for properties roughly matching CUDA API version + unsafe { *version = 1100 }; + Ok(()) +} + +pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> { + let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?; + unsafe { *dev = dev_idx }; + Ok(()) +} + +pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + ctx.ref_count.incr()?; + Ok::<_, CUresult>(unchecked_ctx as *mut _) + })??; + unsafe { *pctx = ctx }; + Ok(()) +} + +pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + if ctx.ref_count.decr() { + Context::destroy_impl(unchecked_ctx)?; + } + Ok::<_, CUresult>(()) + })? +} + +pub(crate) fn synchronize() -> CUresult { + // TODO: change the implementation once we do async stream operations + CUresult::CUDA_SUCCESS +} + +#[cfg(test)] +mod test { + use super::super::test::CudaDriverFns; + use super::super::CUresult; + use std::{ffi::c_void, ptr}; + + cuda_driver_test!(destroy_leaves_zombie_context); + + fn destroy_leaves_zombie_context<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx1 = ptr::null_mut(); + let mut ctx2 = ptr::null_mut(); + let mut ctx3 = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx1, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxCreate_v2(&mut ctx3, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + let mut popped_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut popped_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(popped_ctx1, ctx3); + let mut popped_ctx2 = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut popped_ctx2), + CUresult::CUDA_SUCCESS + ); + assert_eq!(popped_ctx2, ctx2); + let mut popped_ctx3 = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut popped_ctx3), + CUresult::CUDA_SUCCESS + ); + assert_eq!(popped_ctx3, ctx1); + let mut temp = 0; + assert_eq!( + T::cuCtxGetApiVersion(ctx2, &mut temp), + CUresult::CUDA_ERROR_INVALID_CONTEXT + ); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut ptr::null_mut()), + CUresult::CUDA_ERROR_INVALID_CONTEXT + ); + } + + cuda_driver_test!(empty_pop_fails); + + fn empty_pop_fails<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut ctx), + CUresult::CUDA_ERROR_INVALID_CONTEXT + ); + } + + cuda_driver_test!(destroy_pops_top_of_stack); + + fn destroy_pops_top_of_stack<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx1 = ptr::null_mut(); + let mut ctx2 = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx1, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + let mut popped_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut popped_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(popped_ctx1, ctx1); + let mut popped_ctx2 = ptr::null_mut(); + assert_eq!( + T::cuCtxPopCurrent_v2(&mut popped_ctx2), + CUresult::CUDA_ERROR_INVALID_CONTEXT + ); + } + + cuda_driver_test!(double_destroy_fails); + + fn double_destroy_fails<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + let destroy_result = T::cuCtxDestroy_v2(ctx); + // original CUDA impl returns randomly one or the other + assert!( + destroy_result == CUresult::CUDA_ERROR_INVALID_CONTEXT + || destroy_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED + ); + } + + cuda_driver_test!(no_current_on_init); + + fn no_current_on_init<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = 1 as *mut c_void; + assert_eq!(T::cuCtxGetCurrent(&mut ctx), CUresult::CUDA_SUCCESS); + assert_eq!(ctx, ptr::null_mut()); + } +} |