diff options
Diffstat (limited to 'zluda/src/impl/context.rs')
-rw-r--r-- | zluda/src/impl/context.rs | 126 |
1 files changed, 92 insertions, 34 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 429338b..d1b3e7b 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -7,7 +7,7 @@ use cuda_types::*; use hip_runtime_sys::*; use rustc_hash::{FxHashMap, FxHashSet}; use std::ptr; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::atomic::AtomicU32; use std::sync::Mutex; use std::{cell::RefCell, ffi::c_void}; @@ -28,57 +28,104 @@ impl ZludaObject for ContextData { const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; fn drop_with_result(&mut self, _: bool) -> Result<(), CUresult> { - let mutable = self - .mutable - .get_mut() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - fold_cuda_errors(mutable.streams.iter().copied().map(|s| { - unsafe { LiveCheck::drop_box_with_result(s, true)? }; - Ok(()) - })) + self.with_inner_mut(|mutable| { + fold_cuda_errors( + mutable + .streams + .iter() + .copied() + .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }), + ) + })? } } pub(crate) struct ContextData { - pub(crate) flags: AtomicU32, - is_primary: bool, - pub(crate) ref_count: AtomicU32, pub(crate) device: hipDevice_t, - pub(crate) mutable: Mutex<ContextDataMutable>, + pub(crate) variant: ContextVariant, +} + +pub(crate) enum ContextVariant { + NonPrimary(NonPrimaryContextData), + Primary(Mutex<PrimaryContextData>), +} + +pub(crate) struct PrimaryContextData { + pub(crate) ref_count: u32, + pub(crate) flags: u32, + pub(crate) mutable: ContextInnerMutable, +} + +pub(crate) struct NonPrimaryContextData { + flags: AtomicU32, + mutable: Mutex<ContextInnerMutable>, } impl ContextData { - pub(crate) fn new( - flags: u32, - device: hipDevice_t, - is_primary: bool, - initial_refcount: u32, - ) -> Result<Self, CUresult> { - Ok(ContextData { - flags: AtomicU32::new(flags), + pub(crate) fn new_non_primary(flags: u32, device: hipDevice_t) -> Self { + Self { + device, + variant: ContextVariant::NonPrimary(NonPrimaryContextData { + flags: AtomicU32::new(flags), + mutable: Mutex::new(ContextInnerMutable::new()), + }), + } + } + + pub(crate) fn new_primary(device: hipDevice_t) -> Self { + Self { device, - ref_count: AtomicU32::new(initial_refcount), - is_primary, - mutable: Mutex::new(ContextDataMutable::new()), + variant: ContextVariant::Primary(Mutex::new(PrimaryContextData { + ref_count: 0, + flags: 0, + mutable: ContextInnerMutable::new(), + })), + } + } + + pub(crate) fn with_inner_mut<T>( + &self, + fn_: impl FnOnce(&mut ContextInnerMutable) -> T, + ) -> Result<T, CUresult> { + Ok(match self.variant { + ContextVariant::Primary(ref mutex_over_primary_ctx_data) => { + let mut primary_ctx_data = mutex_over_primary_ctx_data + .lock() + .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; + fn_(&mut primary_ctx_data.mutable) + } + ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => { + let mut ctx_data_mutable = + mutable.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; + fn_(&mut ctx_data_mutable) + } }) } } -pub(crate) struct ContextDataMutable { +pub(crate) struct ContextInnerMutable { pub(crate) streams: FxHashSet<*mut stream::Stream>, pub(crate) modules: FxHashSet<*mut module::Module>, // Field below is here to support CUDA Driver Dark API pub(crate) local_storage: FxHashMap<*mut c_void, LocalStorageValue>, } -impl ContextDataMutable { - fn new() -> Self { - ContextDataMutable { +impl ContextInnerMutable { + pub(crate) fn new() -> Self { + ContextInnerMutable { streams: FxHashSet::default(), modules: FxHashSet::default(), local_storage: FxHashMap::default(), } } + pub(crate) fn drop_with_result(&mut self) -> Result<(), CUresult> { + fold_cuda_errors( + self.streams + .iter() + .copied() + .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }), + ) + } } pub(crate) struct LocalStorageValue { @@ -94,7 +141,7 @@ pub(crate) unsafe fn create( if pctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let context_box = Box::new(LiveCheck::new(ContextData::new(flags, dev, false, 1)?)); + let context_box = Box::new(LiveCheck::new(ContextData::new_non_primary(flags, dev))); let context_ptr = Box::into_raw(context_box); *pctx = context_ptr; push_context_stack(context_ptr) @@ -105,7 +152,7 @@ pub(crate) unsafe fn destroy(ctx: *mut Context) -> Result<(), CUresult> { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } let ctx_ref = LiveCheck::as_result(ctx)?; - if ctx_ref.is_primary { + if let ContextVariant::Primary { .. } = ctx_ref.variant { return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); } CONTEXT_STACK.with(|stack| { @@ -175,14 +222,25 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult> Ok(()) } +pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> { + with_current(|ctx| match ctx.variant { + ContextVariant::NonPrimary(ref context) => { + context + .flags + .store(flags, std::sync::atomic::Ordering::SeqCst); + Ok(()) + } + // This looks stupid, but this is an actual CUDA behavior, + // see primary_context.rs test + ContextVariant::Primary(_) => Ok(()), + })? +} + pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { if ctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); } - let ctx = LiveCheck::as_result(ctx)?; - if ctx.ref_count.load(Ordering::Acquire) == 0 { - return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); - } + //let ctx = LiveCheck::as_result(ctx)?; //TODO: query device for properties roughly matching CUDA API version *version = 3020; Ok(()) |