diff options
Diffstat (limited to 'zluda/src')
-rw-r--r-- | zluda/src/cuda.rs | 5 | ||||
-rw-r--r-- | zluda/src/impl/context.rs | 126 | ||||
-rw-r--r-- | zluda/src/impl/dark_api.rs | 84 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 105 | ||||
-rw-r--r-- | zluda/src/impl/mod.rs | 4 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 20 | ||||
-rw-r--r-- | zluda/src/impl/stream.rs | 20 |
7 files changed, 204 insertions, 160 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs index 898d732..1d054c3 100644 --- a/zluda/src/cuda.rs +++ b/zluda/src/cuda.rs @@ -69,6 +69,7 @@ cuda_function_declarations!( cuCtxGetDevice,
cuCtxGetLimit,
cuCtxSetLimit,
+ cuCtxSetFlags,
cuCtxGetStreamPriorityRange,
cuCtxSynchronize,
cuCtxSetCacheConfig,
@@ -485,6 +486,10 @@ mod definitions { context::set_limit(limit, value)
}
+ pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
+ context::set_flags(flags)
+ }
+
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
leastPriority: *mut ::std::os::raw::c_int,
greatestPriority: *mut ::std::os::raw::c_int,
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 429338b..d1b3e7b 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -7,7 +7,7 @@ use cuda_types::*; use hip_runtime_sys::*; use rustc_hash::{FxHashMap, FxHashSet}; use std::ptr; -use std::sync::atomic::{AtomicU32, Ordering}; +use std::sync::atomic::AtomicU32; use std::sync::Mutex; use std::{cell::RefCell, ffi::c_void}; @@ -28,57 +28,104 @@ impl ZludaObject for ContextData { const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; fn drop_with_result(&mut self, _: bool) -> Result<(), CUresult> { - let mutable = self - .mutable - .get_mut() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - fold_cuda_errors(mutable.streams.iter().copied().map(|s| { - unsafe { LiveCheck::drop_box_with_result(s, true)? }; - Ok(()) - })) + self.with_inner_mut(|mutable| { + fold_cuda_errors( + mutable + .streams + .iter() + .copied() + .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }), + ) + })? } } pub(crate) struct ContextData { - pub(crate) flags: AtomicU32, - is_primary: bool, - pub(crate) ref_count: AtomicU32, pub(crate) device: hipDevice_t, - pub(crate) mutable: Mutex<ContextDataMutable>, + pub(crate) variant: ContextVariant, +} + +pub(crate) enum ContextVariant { + NonPrimary(NonPrimaryContextData), + Primary(Mutex<PrimaryContextData>), +} + +pub(crate) struct PrimaryContextData { + pub(crate) ref_count: u32, + pub(crate) flags: u32, + pub(crate) mutable: ContextInnerMutable, +} + +pub(crate) struct NonPrimaryContextData { + flags: AtomicU32, + mutable: Mutex<ContextInnerMutable>, } impl ContextData { - pub(crate) fn new( - flags: u32, - device: hipDevice_t, - is_primary: bool, - initial_refcount: u32, - ) -> Result<Self, CUresult> { - Ok(ContextData { - flags: AtomicU32::new(flags), + pub(crate) fn new_non_primary(flags: u32, device: hipDevice_t) -> Self { + Self { + device, + variant: ContextVariant::NonPrimary(NonPrimaryContextData { + flags: AtomicU32::new(flags), + mutable: Mutex::new(ContextInnerMutable::new()), + }), + } + } + + pub(crate) fn new_primary(device: hipDevice_t) -> Self { + Self { device, - ref_count: AtomicU32::new(initial_refcount), - is_primary, - mutable: Mutex::new(ContextDataMutable::new()), + variant: ContextVariant::Primary(Mutex::new(PrimaryContextData { + ref_count: 0, + flags: 0, + mutable: ContextInnerMutable::new(), + })), + } + } + + pub(crate) fn with_inner_mut<T>( + &self, + fn_: impl FnOnce(&mut ContextInnerMutable) -> T, + ) -> Result<T, CUresult> { + Ok(match self.variant { + ContextVariant::Primary(ref mutex_over_primary_ctx_data) => { + let mut primary_ctx_data = mutex_over_primary_ctx_data + .lock() + .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; + fn_(&mut primary_ctx_data.mutable) + } + ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => { + let mut ctx_data_mutable = + mutable.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; + fn_(&mut ctx_data_mutable) + } }) } } -pub(crate) struct ContextDataMutable { +pub(crate) struct ContextInnerMutable { pub(crate) streams: FxHashSet<*mut stream::Stream>, pub(crate) modules: FxHashSet<*mut module::Module>, // Field below is here to support CUDA Driver Dark API pub(crate) local_storage: FxHashMap<*mut c_void, LocalStorageValue>, } -impl ContextDataMutable { - fn new() -> Self { - ContextDataMutable { +impl ContextInnerMutable { + pub(crate) fn new() -> Self { + ContextInnerMutable { streams: FxHashSet::default(), modules: FxHashSet::default(), local_storage: FxHashMap::default(), } } + pub(crate) fn drop_with_result(&mut self) -> Result<(), CUresult> { + fold_cuda_errors( + self.streams + .iter() + .copied() + .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }), + ) + } } pub(crate) struct LocalStorageValue { @@ -94,7 +141,7 @@ pub(crate) unsafe fn create( if pctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let context_box = Box::new(LiveCheck::new(ContextData::new(flags, dev, false, 1)?)); + let context_box = Box::new(LiveCheck::new(ContextData::new_non_primary(flags, dev))); let context_ptr = Box::into_raw(context_box); *pctx = context_ptr; push_context_stack(context_ptr) @@ -105,7 +152,7 @@ pub(crate) unsafe fn destroy(ctx: *mut Context) -> Result<(), CUresult> { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } let ctx_ref = LiveCheck::as_result(ctx)?; - if ctx_ref.is_primary { + if let ContextVariant::Primary { .. } = ctx_ref.variant { return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); } CONTEXT_STACK.with(|stack| { @@ -175,14 +222,25 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult> Ok(()) } +pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> { + with_current(|ctx| match ctx.variant { + ContextVariant::NonPrimary(ref context) => { + context + .flags + .store(flags, std::sync::atomic::Ordering::SeqCst); + Ok(()) + } + // This looks stupid, but this is an actual CUDA behavior, + // see primary_context.rs test + ContextVariant::Primary(_) => Ok(()), + })? +} + pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { if ctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); } - let ctx = LiveCheck::as_result(ctx)?; - if ctx.ref_count.load(Ordering::Acquire) == 0 { - return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); - } + //let ctx = LiveCheck::as_result(ctx)?; //TODO: query device for properties roughly matching CUDA API version *version = 3020; Ok(()) diff --git a/zluda/src/impl/dark_api.rs b/zluda/src/impl/dark_api.rs index c3f4fca..c3b596c 100644 --- a/zluda/src/impl/dark_api.rs +++ b/zluda/src/impl/dark_api.rs @@ -121,20 +121,27 @@ impl CudaDarkApi for CudaDarkApiZluda { value: *mut c_void, dtor_callback: Option<extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void)>, ) -> CUresult { - with_context_or_current(cu_ctx, |ctx| { - let mut ctx_mutable = ctx - .mutable - .lock() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - ctx_mutable.local_storage.insert( - key, - LocalStorageValue { - value, - _dtor_callback: dtor_callback, - }, - ); - Ok(()) - }) + unsafe fn context_local_storage_insert_impl( + cu_ctx: cuda_types::CUcontext, + key: *mut c_void, + value: *mut c_void, + dtor_callback: Option< + extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void), + >, + ) -> Result<(), CUresult> { + with_context_or_current(cu_ctx, |ctx| { + ctx.with_inner_mut(|ctx_mutable| { + ctx_mutable.local_storage.insert( + key, + LocalStorageValue { + value, + _dtor_callback: dtor_callback, + }, + ); + }) + })? + } + context_local_storage_insert_impl(cu_ctx, key, value, dtor_callback).into_cuda() } // TODO @@ -143,29 +150,30 @@ impl CudaDarkApi for CudaDarkApiZluda { } unsafe extern "system" fn context_local_storage_get( - result: *mut *mut c_void, + cu_result: *mut *mut c_void, cu_ctx: cuda_types::CUcontext, key: *mut c_void, ) -> CUresult { - let mut cu_result = None; - let query_cu_result = with_context_or_current(cu_ctx, |ctx| { - let ctx_mutable = ctx - .mutable - .lock() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - cu_result = ctx_mutable.local_storage.get(&key).map(|v| v.value); - Ok(()) - }); - if query_cu_result != CUresult::CUDA_SUCCESS { - query_cu_result - } else { - match cu_result { - Some(value) => { - *result = value; - CUresult::CUDA_SUCCESS - } - None => CUresult::CUDA_ERROR_INVALID_VALUE, + unsafe fn context_local_storage_get_impl( + cu_ctx: cuda_types::CUcontext, + key: *mut c_void, + ) -> Result<*mut c_void, CUresult> { + with_context_or_current(cu_ctx, |ctx| { + ctx.with_inner_mut(|ctx_mutable| { + ctx_mutable + .local_storage + .get(&key) + .map(|v| v.value) + .ok_or(CUresult::CUDA_ERROR_INVALID_VALUE) + })? + })? + } + match context_local_storage_get_impl(cu_ctx, key) { + Ok(result) => { + *cu_result = result; + CUresult::CUDA_SUCCESS } + Err(err) => err, } } @@ -386,14 +394,14 @@ impl CudaDarkApi for CudaDarkApiZluda { } } -unsafe fn with_context_or_current( +unsafe fn with_context_or_current<T>( ctx: CUcontext, - f: impl FnOnce(&context::ContextData) -> Result<(), CUresult>, -) -> CUresult { + fn_: impl FnOnce(&context::ContextData) -> T, +) -> Result<T, CUresult> { if ctx == ptr::null_mut() { - context::with_current(|c| f(c)).into_cuda() + context::with_current(|c| fn_(c)) } else { let ctx = FromCuda::from_cuda(ctx); - LiveCheck::as_result(ctx).map(f).into_cuda() + Ok(fn_(LiveCheck::as_result(ctx)?)) } } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 59201e2..c7e8190 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -1,6 +1,8 @@ +use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData}; use super::{ - context, LiveCheck, GLOBAL_STATE, + context, LiveCheck, GLOBAL_STATE }; +use crate::r#impl::context::ContextData; use crate::{r#impl::IntoCuda, hip_call_cuda}; use crate::hip_call; use cuda_types::{CUdevice_attribute, CUdevprop, CUuuid_st, CUresult}; @@ -10,11 +12,7 @@ use paste::paste; use std::{ mem, os::raw::{c_char, c_uint}, - ptr, - sync::{ - atomic::AtomicU32, - Mutex, - }, ops::AddAssign, ffi::CString, + ptr,ffi::CString, }; const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0"; @@ -28,9 +26,7 @@ pub const COMPUTE_CAPABILITY_MINOR: u32 = 8; pub(crate) struct Device { pub(crate) compilation_mode: CompilationMode, pub(crate) comgr_isa: CString, - // Primary context is lazy-initialized, the mutex is here to secure retain - // from multiple threads - primary_context: Mutex<Option<context::Context>>, + primary_context: context::Context, } impl Device { @@ -48,7 +44,7 @@ impl Device { Ok(Self { compilation_mode, comgr_isa, - primary_context: Mutex::new(None), + primary_context: LiveCheck::new(ContextData::new_primary(index as i32)), }) } } @@ -516,38 +512,29 @@ unsafe fn primary_ctx_get_or_retain( if pctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let ctx = primary_ctx(hip_dev, |ctx| { - let ctx = match ctx { - Some(ref mut ctx) => ctx, - None => { - ctx.insert(LiveCheck::new(context::ContextData::new(0, hip_dev, true, 0)?)) - }, - }; - if increment_refcount { - ctx.as_mut_unchecked().ref_count.get_mut().add_assign(1); + let ctx = primary_ctx(hip_dev, |ctx, raw_ctx| { + if increment_refcount || ctx.ref_count == 0 { + ctx.ref_count += 1; } - Ok(ctx as *mut _) + Ok(raw_ctx.cast_mut()) })??; *pctx = ctx; Ok(()) } pub(crate) unsafe fn primary_ctx_release(hip_dev: hipDevice_t) -> Result<(), CUresult> { - primary_ctx(hip_dev, move |maybe_ctx| { - if let Some(ctx) = maybe_ctx { - let ctx_data = ctx.as_mut_unchecked(); - let ref_count = ctx_data.ref_count.get_mut(); - *ref_count -= 1; - if *ref_count == 0 { - //TODO: fix - //ctx.try_drop(false) - Ok(()) - } else { - Ok(()) - } - } else { - Err(CUresult::CUDA_ERROR_INVALID_CONTEXT) + primary_ctx(hip_dev, |ctx, _| { + if ctx.ref_count == 0 { + return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); + } + ctx.ref_count -= 1; + if ctx.ref_count == 0 { + // Even if we encounter errors we can't really surface them + ctx.mutable.drop_with_result().ok(); + ctx.mutable = ContextInnerMutable::new(); + ctx.flags = 0; } + Ok(()) })? } @@ -566,53 +553,43 @@ pub(crate) unsafe fn primary_ctx_set_flags( hip_dev: hipDevice_t, flags: ::std::os::raw::c_uint, ) -> Result<(), CUresult> { - primary_ctx(hip_dev, move |maybe_ctx| { - if let Some(ctx) = maybe_ctx { - let ctx = ctx.as_mut_unchecked(); - ctx.flags = AtomicU32::new(flags); - Ok(()) - } else { - Err(CUresult::CUDA_ERROR_INVALID_CONTEXT) - } + primary_ctx(hip_dev, |ctx, _| { + ctx.flags = flags; + // TODO: actually use flags + Ok(()) })? } pub(crate) unsafe fn primary_ctx_get_state( hip_dev: hipDevice_t, - flags_ptr: *mut ::std::os::raw::c_uint, - active_ptr: *mut ::std::os::raw::c_int, + flags_ptr: *mut u32, + active_ptr: *mut i32, ) -> Result<(), CUresult> { if flags_ptr == ptr::null_mut() || active_ptr == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let maybe_flags = primary_ctx(hip_dev, move |maybe_ctx| { - if let Some(ctx) = maybe_ctx { - let ctx = ctx.as_mut_unchecked(); - Some(*ctx.flags.get_mut()) - } else { - None - } + let (flags, active) = primary_ctx(hip_dev, |ctx, _| { + (ctx.flags, (ctx.ref_count > 0) as i32) })?; - if let Some(flags) = maybe_flags { - *flags_ptr = flags; - *active_ptr = 1; - } else { - *flags_ptr = 0; - *active_ptr = 0; - } + *flags_ptr = flags; + *active_ptr = active; Ok(()) } pub(crate) unsafe fn primary_ctx<T>( dev: hipDevice_t, - f: impl FnOnce(&mut Option<context::Context>) -> T, + fn_: impl FnOnce(&mut PrimaryContextData, *const LiveCheck<ContextData>) -> T, ) -> Result<T, CUresult> { let device = GLOBAL_STATE.get()?.device(dev)?; - let mut maybe_primary_context = device - .primary_context - .lock() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - Ok(f(&mut maybe_primary_context)) + let raw_ptr = &device.primary_context as *const _; + let context = device.primary_context.as_ref_unchecked(); + match context.variant { + ContextVariant::Primary(ref mutex_over_primary_ctx) => { + let mut primary_ctx = mutex_over_primary_ctx.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; + Ok(fn_(&mut primary_ctx, raw_ptr)) + }, + ContextVariant::NonPrimary(..) => Err(CUresult::CUDA_ERROR_UNKNOWN) + } } pub(crate) unsafe fn get_name(name: *mut i8, len: i32, device: i32) -> hipError_t { diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 88a95c4..34566af 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -148,6 +148,10 @@ impl<T: ZludaObject> LiveCheck<T> { outer_ptr as *mut Self } + pub unsafe fn as_ref_unchecked(&self) -> & T { + &self.data + } + pub unsafe fn as_mut_unchecked(&mut self) -> &mut T { &mut self.data } diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 6a6911a..8a49d43 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -31,13 +31,11 @@ impl ZludaObject for ModuleData { let deregistration_err = if !by_owner { if let Some(ctx) = self.owner { let ctx = unsafe { LiveCheck::as_result(ctx.as_ptr())? }; - let mut ctx_mutable = ctx - .mutable - .lock() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - ctx_mutable - .modules - .remove(&unsafe { LiveCheck::from_raw(self) }); + ctx.with_inner_mut(|ctx_mutable| { + ctx_mutable + .modules + .remove(&unsafe { LiveCheck::from_raw(self) }); + })?; } Ok(()) } else { @@ -104,11 +102,9 @@ pub(crate) unsafe fn load_impl( isa, input, )?); - let mut ctx_mutable = ctx - .mutable - .lock() - .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; - ctx_mutable.modules.insert(module); + ctx.with_inner_mut(|ctx_mutable| { + ctx_mutable.modules.insert(module); + })?; *output = module; Ok(()) })? diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs index fb53510..71ed20b 100644 --- a/zluda/src/impl/stream.rs +++ b/zluda/src/impl/stream.rs @@ -21,13 +21,11 @@ impl ZludaObject for StreamData { if !by_owner {
let ctx = unsafe { LiveCheck::as_result(self.ctx)? };
{
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable
- .streams
- .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable
+ .streams
+ .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+ })?;
}
}
hip_call_cuda!(hipStreamDestroy(self.base));
@@ -59,11 +57,9 @@ pub(crate) unsafe fn create_with_priority( ctx: ptr::null_mut(),
})));
let ctx = context::with_current(|ctx| {
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable.streams.insert(stream);
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable.streams.insert(stream);
+ })?;
Ok(LiveCheck::from_raw(ctx as *const _ as _))
})??;
(*stream).as_mut_unchecked().ctx = ctx;
|