aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/context.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/context.rs')
-rw-r--r--zluda/src/impl/context.rs126
1 files changed, 92 insertions, 34 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index 429338b..d1b3e7b 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -7,7 +7,7 @@ use cuda_types::*;
use hip_runtime_sys::*;
use rustc_hash::{FxHashMap, FxHashSet};
use std::ptr;
-use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::atomic::AtomicU32;
use std::sync::Mutex;
use std::{cell::RefCell, ffi::c_void};
@@ -28,57 +28,104 @@ impl ZludaObject for ContextData {
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
fn drop_with_result(&mut self, _: bool) -> Result<(), CUresult> {
- let mutable = self
- .mutable
- .get_mut()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- fold_cuda_errors(mutable.streams.iter().copied().map(|s| {
- unsafe { LiveCheck::drop_box_with_result(s, true)? };
- Ok(())
- }))
+ self.with_inner_mut(|mutable| {
+ fold_cuda_errors(
+ mutable
+ .streams
+ .iter()
+ .copied()
+ .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+ )
+ })?
}
}
pub(crate) struct ContextData {
- pub(crate) flags: AtomicU32,
- is_primary: bool,
- pub(crate) ref_count: AtomicU32,
pub(crate) device: hipDevice_t,
- pub(crate) mutable: Mutex<ContextDataMutable>,
+ pub(crate) variant: ContextVariant,
+}
+
+pub(crate) enum ContextVariant {
+ NonPrimary(NonPrimaryContextData),
+ Primary(Mutex<PrimaryContextData>),
+}
+
+pub(crate) struct PrimaryContextData {
+ pub(crate) ref_count: u32,
+ pub(crate) flags: u32,
+ pub(crate) mutable: ContextInnerMutable,
+}
+
+pub(crate) struct NonPrimaryContextData {
+ flags: AtomicU32,
+ mutable: Mutex<ContextInnerMutable>,
}
impl ContextData {
- pub(crate) fn new(
- flags: u32,
- device: hipDevice_t,
- is_primary: bool,
- initial_refcount: u32,
- ) -> Result<Self, CUresult> {
- Ok(ContextData {
- flags: AtomicU32::new(flags),
+ pub(crate) fn new_non_primary(flags: u32, device: hipDevice_t) -> Self {
+ Self {
+ device,
+ variant: ContextVariant::NonPrimary(NonPrimaryContextData {
+ flags: AtomicU32::new(flags),
+ mutable: Mutex::new(ContextInnerMutable::new()),
+ }),
+ }
+ }
+
+ pub(crate) fn new_primary(device: hipDevice_t) -> Self {
+ Self {
device,
- ref_count: AtomicU32::new(initial_refcount),
- is_primary,
- mutable: Mutex::new(ContextDataMutable::new()),
+ variant: ContextVariant::Primary(Mutex::new(PrimaryContextData {
+ ref_count: 0,
+ flags: 0,
+ mutable: ContextInnerMutable::new(),
+ })),
+ }
+ }
+
+ pub(crate) fn with_inner_mut<T>(
+ &self,
+ fn_: impl FnOnce(&mut ContextInnerMutable) -> T,
+ ) -> Result<T, CUresult> {
+ Ok(match self.variant {
+ ContextVariant::Primary(ref mutex_over_primary_ctx_data) => {
+ let mut primary_ctx_data = mutex_over_primary_ctx_data
+ .lock()
+ .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ fn_(&mut primary_ctx_data.mutable)
+ }
+ ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => {
+ let mut ctx_data_mutable =
+ mutable.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ fn_(&mut ctx_data_mutable)
+ }
})
}
}
-pub(crate) struct ContextDataMutable {
+pub(crate) struct ContextInnerMutable {
pub(crate) streams: FxHashSet<*mut stream::Stream>,
pub(crate) modules: FxHashSet<*mut module::Module>,
// Field below is here to support CUDA Driver Dark API
pub(crate) local_storage: FxHashMap<*mut c_void, LocalStorageValue>,
}
-impl ContextDataMutable {
- fn new() -> Self {
- ContextDataMutable {
+impl ContextInnerMutable {
+ pub(crate) fn new() -> Self {
+ ContextInnerMutable {
streams: FxHashSet::default(),
modules: FxHashSet::default(),
local_storage: FxHashMap::default(),
}
}
+ pub(crate) fn drop_with_result(&mut self) -> Result<(), CUresult> {
+ fold_cuda_errors(
+ self.streams
+ .iter()
+ .copied()
+ .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+ )
+ }
}
pub(crate) struct LocalStorageValue {
@@ -94,7 +141,7 @@ pub(crate) unsafe fn create(
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let context_box = Box::new(LiveCheck::new(ContextData::new(flags, dev, false, 1)?));
+ let context_box = Box::new(LiveCheck::new(ContextData::new_non_primary(flags, dev)));
let context_ptr = Box::into_raw(context_box);
*pctx = context_ptr;
push_context_stack(context_ptr)
@@ -105,7 +152,7 @@ pub(crate) unsafe fn destroy(ctx: *mut Context) -> Result<(), CUresult> {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx_ref = LiveCheck::as_result(ctx)?;
- if ctx_ref.is_primary {
+ if let ContextVariant::Primary { .. } = ctx_ref.variant {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
}
CONTEXT_STACK.with(|stack| {
@@ -175,14 +222,25 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
Ok(())
}
+pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
+ with_current(|ctx| match ctx.variant {
+ ContextVariant::NonPrimary(ref context) => {
+ context
+ .flags
+ .store(flags, std::sync::atomic::Ordering::SeqCst);
+ Ok(())
+ }
+ // This looks stupid, but this is an actual CUDA behavior,
+ // see primary_context.rs test
+ ContextVariant::Primary(_) => Ok(()),
+ })?
+}
+
pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
}
- let ctx = LiveCheck::as_result(ctx)?;
- if ctx.ref_count.load(Ordering::Acquire) == 0 {
- return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
- }
+ //let ctx = LiveCheck::as_result(ctx)?;
//TODO: query device for properties roughly matching CUDA API version
*version = 3020;
Ok(())