aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src')
-rw-r--r--zluda/src/cuda.rs5
-rw-r--r--zluda/src/impl/context.rs126
-rw-r--r--zluda/src/impl/dark_api.rs84
-rw-r--r--zluda/src/impl/device.rs105
-rw-r--r--zluda/src/impl/mod.rs4
-rw-r--r--zluda/src/impl/module.rs20
-rw-r--r--zluda/src/impl/stream.rs20
7 files changed, 204 insertions, 160 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs
index 898d732..1d054c3 100644
--- a/zluda/src/cuda.rs
+++ b/zluda/src/cuda.rs
@@ -69,6 +69,7 @@ cuda_function_declarations!(
cuCtxGetDevice,
cuCtxGetLimit,
cuCtxSetLimit,
+ cuCtxSetFlags,
cuCtxGetStreamPriorityRange,
cuCtxSynchronize,
cuCtxSetCacheConfig,
@@ -485,6 +486,10 @@ mod definitions {
context::set_limit(limit, value)
}
+ pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
+ context::set_flags(flags)
+ }
+
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
leastPriority: *mut ::std::os::raw::c_int,
greatestPriority: *mut ::std::os::raw::c_int,
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index 429338b..d1b3e7b 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -7,7 +7,7 @@ use cuda_types::*;
use hip_runtime_sys::*;
use rustc_hash::{FxHashMap, FxHashSet};
use std::ptr;
-use std::sync::atomic::{AtomicU32, Ordering};
+use std::sync::atomic::AtomicU32;
use std::sync::Mutex;
use std::{cell::RefCell, ffi::c_void};
@@ -28,57 +28,104 @@ impl ZludaObject for ContextData {
const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
fn drop_with_result(&mut self, _: bool) -> Result<(), CUresult> {
- let mutable = self
- .mutable
- .get_mut()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- fold_cuda_errors(mutable.streams.iter().copied().map(|s| {
- unsafe { LiveCheck::drop_box_with_result(s, true)? };
- Ok(())
- }))
+ self.with_inner_mut(|mutable| {
+ fold_cuda_errors(
+ mutable
+ .streams
+ .iter()
+ .copied()
+ .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+ )
+ })?
}
}
pub(crate) struct ContextData {
- pub(crate) flags: AtomicU32,
- is_primary: bool,
- pub(crate) ref_count: AtomicU32,
pub(crate) device: hipDevice_t,
- pub(crate) mutable: Mutex<ContextDataMutable>,
+ pub(crate) variant: ContextVariant,
+}
+
+pub(crate) enum ContextVariant {
+ NonPrimary(NonPrimaryContextData),
+ Primary(Mutex<PrimaryContextData>),
+}
+
+pub(crate) struct PrimaryContextData {
+ pub(crate) ref_count: u32,
+ pub(crate) flags: u32,
+ pub(crate) mutable: ContextInnerMutable,
+}
+
+pub(crate) struct NonPrimaryContextData {
+ flags: AtomicU32,
+ mutable: Mutex<ContextInnerMutable>,
}
impl ContextData {
- pub(crate) fn new(
- flags: u32,
- device: hipDevice_t,
- is_primary: bool,
- initial_refcount: u32,
- ) -> Result<Self, CUresult> {
- Ok(ContextData {
- flags: AtomicU32::new(flags),
+ pub(crate) fn new_non_primary(flags: u32, device: hipDevice_t) -> Self {
+ Self {
+ device,
+ variant: ContextVariant::NonPrimary(NonPrimaryContextData {
+ flags: AtomicU32::new(flags),
+ mutable: Mutex::new(ContextInnerMutable::new()),
+ }),
+ }
+ }
+
+ pub(crate) fn new_primary(device: hipDevice_t) -> Self {
+ Self {
device,
- ref_count: AtomicU32::new(initial_refcount),
- is_primary,
- mutable: Mutex::new(ContextDataMutable::new()),
+ variant: ContextVariant::Primary(Mutex::new(PrimaryContextData {
+ ref_count: 0,
+ flags: 0,
+ mutable: ContextInnerMutable::new(),
+ })),
+ }
+ }
+
+ pub(crate) fn with_inner_mut<T>(
+ &self,
+ fn_: impl FnOnce(&mut ContextInnerMutable) -> T,
+ ) -> Result<T, CUresult> {
+ Ok(match self.variant {
+ ContextVariant::Primary(ref mutex_over_primary_ctx_data) => {
+ let mut primary_ctx_data = mutex_over_primary_ctx_data
+ .lock()
+ .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ fn_(&mut primary_ctx_data.mutable)
+ }
+ ContextVariant::NonPrimary(NonPrimaryContextData { ref mutable, .. }) => {
+ let mut ctx_data_mutable =
+ mutable.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ fn_(&mut ctx_data_mutable)
+ }
})
}
}
-pub(crate) struct ContextDataMutable {
+pub(crate) struct ContextInnerMutable {
pub(crate) streams: FxHashSet<*mut stream::Stream>,
pub(crate) modules: FxHashSet<*mut module::Module>,
// Field below is here to support CUDA Driver Dark API
pub(crate) local_storage: FxHashMap<*mut c_void, LocalStorageValue>,
}
-impl ContextDataMutable {
- fn new() -> Self {
- ContextDataMutable {
+impl ContextInnerMutable {
+ pub(crate) fn new() -> Self {
+ ContextInnerMutable {
streams: FxHashSet::default(),
modules: FxHashSet::default(),
local_storage: FxHashMap::default(),
}
}
+ pub(crate) fn drop_with_result(&mut self) -> Result<(), CUresult> {
+ fold_cuda_errors(
+ self.streams
+ .iter()
+ .copied()
+ .map(|s| unsafe { LiveCheck::drop_box_with_result(s, true) }),
+ )
+ }
}
pub(crate) struct LocalStorageValue {
@@ -94,7 +141,7 @@ pub(crate) unsafe fn create(
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let context_box = Box::new(LiveCheck::new(ContextData::new(flags, dev, false, 1)?));
+ let context_box = Box::new(LiveCheck::new(ContextData::new_non_primary(flags, dev)));
let context_ptr = Box::into_raw(context_box);
*pctx = context_ptr;
push_context_stack(context_ptr)
@@ -105,7 +152,7 @@ pub(crate) unsafe fn destroy(ctx: *mut Context) -> Result<(), CUresult> {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let ctx_ref = LiveCheck::as_result(ctx)?;
- if ctx_ref.is_primary {
+ if let ContextVariant::Primary { .. } = ctx_ref.variant {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
}
CONTEXT_STACK.with(|stack| {
@@ -175,14 +222,25 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
Ok(())
}
+pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
+ with_current(|ctx| match ctx.variant {
+ ContextVariant::NonPrimary(ref context) => {
+ context
+ .flags
+ .store(flags, std::sync::atomic::Ordering::SeqCst);
+ Ok(())
+ }
+ // This looks stupid, but this is an actual CUDA behavior,
+ // see primary_context.rs test
+ ContextVariant::Primary(_) => Ok(()),
+ })?
+}
+
pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
}
- let ctx = LiveCheck::as_result(ctx)?;
- if ctx.ref_count.load(Ordering::Acquire) == 0 {
- return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
- }
+ //let ctx = LiveCheck::as_result(ctx)?;
//TODO: query device for properties roughly matching CUDA API version
*version = 3020;
Ok(())
diff --git a/zluda/src/impl/dark_api.rs b/zluda/src/impl/dark_api.rs
index c3f4fca..c3b596c 100644
--- a/zluda/src/impl/dark_api.rs
+++ b/zluda/src/impl/dark_api.rs
@@ -121,20 +121,27 @@ impl CudaDarkApi for CudaDarkApiZluda {
value: *mut c_void,
dtor_callback: Option<extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void)>,
) -> CUresult {
- with_context_or_current(cu_ctx, |ctx| {
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable.local_storage.insert(
- key,
- LocalStorageValue {
- value,
- _dtor_callback: dtor_callback,
- },
- );
- Ok(())
- })
+ unsafe fn context_local_storage_insert_impl(
+ cu_ctx: cuda_types::CUcontext,
+ key: *mut c_void,
+ value: *mut c_void,
+ dtor_callback: Option<
+ extern "system" fn(cuda_types::CUcontext, *mut c_void, *mut c_void),
+ >,
+ ) -> Result<(), CUresult> {
+ with_context_or_current(cu_ctx, |ctx| {
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable.local_storage.insert(
+ key,
+ LocalStorageValue {
+ value,
+ _dtor_callback: dtor_callback,
+ },
+ );
+ })
+ })?
+ }
+ context_local_storage_insert_impl(cu_ctx, key, value, dtor_callback).into_cuda()
}
// TODO
@@ -143,29 +150,30 @@ impl CudaDarkApi for CudaDarkApiZluda {
}
unsafe extern "system" fn context_local_storage_get(
- result: *mut *mut c_void,
+ cu_result: *mut *mut c_void,
cu_ctx: cuda_types::CUcontext,
key: *mut c_void,
) -> CUresult {
- let mut cu_result = None;
- let query_cu_result = with_context_or_current(cu_ctx, |ctx| {
- let ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- cu_result = ctx_mutable.local_storage.get(&key).map(|v| v.value);
- Ok(())
- });
- if query_cu_result != CUresult::CUDA_SUCCESS {
- query_cu_result
- } else {
- match cu_result {
- Some(value) => {
- *result = value;
- CUresult::CUDA_SUCCESS
- }
- None => CUresult::CUDA_ERROR_INVALID_VALUE,
+ unsafe fn context_local_storage_get_impl(
+ cu_ctx: cuda_types::CUcontext,
+ key: *mut c_void,
+ ) -> Result<*mut c_void, CUresult> {
+ with_context_or_current(cu_ctx, |ctx| {
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable
+ .local_storage
+ .get(&key)
+ .map(|v| v.value)
+ .ok_or(CUresult::CUDA_ERROR_INVALID_VALUE)
+ })?
+ })?
+ }
+ match context_local_storage_get_impl(cu_ctx, key) {
+ Ok(result) => {
+ *cu_result = result;
+ CUresult::CUDA_SUCCESS
}
+ Err(err) => err,
}
}
@@ -386,14 +394,14 @@ impl CudaDarkApi for CudaDarkApiZluda {
}
}
-unsafe fn with_context_or_current(
+unsafe fn with_context_or_current<T>(
ctx: CUcontext,
- f: impl FnOnce(&context::ContextData) -> Result<(), CUresult>,
-) -> CUresult {
+ fn_: impl FnOnce(&context::ContextData) -> T,
+) -> Result<T, CUresult> {
if ctx == ptr::null_mut() {
- context::with_current(|c| f(c)).into_cuda()
+ context::with_current(|c| fn_(c))
} else {
let ctx = FromCuda::from_cuda(ctx);
- LiveCheck::as_result(ctx).map(f).into_cuda()
+ Ok(fn_(LiveCheck::as_result(ctx)?))
}
}
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 59201e2..c7e8190 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -1,6 +1,8 @@
+use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData};
use super::{
- context, LiveCheck, GLOBAL_STATE,
+ context, LiveCheck, GLOBAL_STATE
};
+use crate::r#impl::context::ContextData;
use crate::{r#impl::IntoCuda, hip_call_cuda};
use crate::hip_call;
use cuda_types::{CUdevice_attribute, CUdevprop, CUuuid_st, CUresult};
@@ -10,11 +12,7 @@ use paste::paste;
use std::{
mem,
os::raw::{c_char, c_uint},
- ptr,
- sync::{
- atomic::AtomicU32,
- Mutex,
- }, ops::AddAssign, ffi::CString,
+ ptr,ffi::CString,
};
const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0";
@@ -28,9 +26,7 @@ pub const COMPUTE_CAPABILITY_MINOR: u32 = 8;
pub(crate) struct Device {
pub(crate) compilation_mode: CompilationMode,
pub(crate) comgr_isa: CString,
- // Primary context is lazy-initialized, the mutex is here to secure retain
- // from multiple threads
- primary_context: Mutex<Option<context::Context>>,
+ primary_context: context::Context,
}
impl Device {
@@ -48,7 +44,7 @@ impl Device {
Ok(Self {
compilation_mode,
comgr_isa,
- primary_context: Mutex::new(None),
+ primary_context: LiveCheck::new(ContextData::new_primary(index as i32)),
})
}
}
@@ -516,38 +512,29 @@ unsafe fn primary_ctx_get_or_retain(
if pctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let ctx = primary_ctx(hip_dev, |ctx| {
- let ctx = match ctx {
- Some(ref mut ctx) => ctx,
- None => {
- ctx.insert(LiveCheck::new(context::ContextData::new(0, hip_dev, true, 0)?))
- },
- };
- if increment_refcount {
- ctx.as_mut_unchecked().ref_count.get_mut().add_assign(1);
+ let ctx = primary_ctx(hip_dev, |ctx, raw_ctx| {
+ if increment_refcount || ctx.ref_count == 0 {
+ ctx.ref_count += 1;
}
- Ok(ctx as *mut _)
+ Ok(raw_ctx.cast_mut())
})??;
*pctx = ctx;
Ok(())
}
pub(crate) unsafe fn primary_ctx_release(hip_dev: hipDevice_t) -> Result<(), CUresult> {
- primary_ctx(hip_dev, move |maybe_ctx| {
- if let Some(ctx) = maybe_ctx {
- let ctx_data = ctx.as_mut_unchecked();
- let ref_count = ctx_data.ref_count.get_mut();
- *ref_count -= 1;
- if *ref_count == 0 {
- //TODO: fix
- //ctx.try_drop(false)
- Ok(())
- } else {
- Ok(())
- }
- } else {
- Err(CUresult::CUDA_ERROR_INVALID_CONTEXT)
+ primary_ctx(hip_dev, |ctx, _| {
+ if ctx.ref_count == 0 {
+ return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
+ }
+ ctx.ref_count -= 1;
+ if ctx.ref_count == 0 {
+ // Even if we encounter errors we can't really surface them
+ ctx.mutable.drop_with_result().ok();
+ ctx.mutable = ContextInnerMutable::new();
+ ctx.flags = 0;
}
+ Ok(())
})?
}
@@ -566,53 +553,43 @@ pub(crate) unsafe fn primary_ctx_set_flags(
hip_dev: hipDevice_t,
flags: ::std::os::raw::c_uint,
) -> Result<(), CUresult> {
- primary_ctx(hip_dev, move |maybe_ctx| {
- if let Some(ctx) = maybe_ctx {
- let ctx = ctx.as_mut_unchecked();
- ctx.flags = AtomicU32::new(flags);
- Ok(())
- } else {
- Err(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- }
+ primary_ctx(hip_dev, |ctx, _| {
+ ctx.flags = flags;
+ // TODO: actually use flags
+ Ok(())
})?
}
pub(crate) unsafe fn primary_ctx_get_state(
hip_dev: hipDevice_t,
- flags_ptr: *mut ::std::os::raw::c_uint,
- active_ptr: *mut ::std::os::raw::c_int,
+ flags_ptr: *mut u32,
+ active_ptr: *mut i32,
) -> Result<(), CUresult> {
if flags_ptr == ptr::null_mut() || active_ptr == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let maybe_flags = primary_ctx(hip_dev, move |maybe_ctx| {
- if let Some(ctx) = maybe_ctx {
- let ctx = ctx.as_mut_unchecked();
- Some(*ctx.flags.get_mut())
- } else {
- None
- }
+ let (flags, active) = primary_ctx(hip_dev, |ctx, _| {
+ (ctx.flags, (ctx.ref_count > 0) as i32)
})?;
- if let Some(flags) = maybe_flags {
- *flags_ptr = flags;
- *active_ptr = 1;
- } else {
- *flags_ptr = 0;
- *active_ptr = 0;
- }
+ *flags_ptr = flags;
+ *active_ptr = active;
Ok(())
}
pub(crate) unsafe fn primary_ctx<T>(
dev: hipDevice_t,
- f: impl FnOnce(&mut Option<context::Context>) -> T,
+ fn_: impl FnOnce(&mut PrimaryContextData, *const LiveCheck<ContextData>) -> T,
) -> Result<T, CUresult> {
let device = GLOBAL_STATE.get()?.device(dev)?;
- let mut maybe_primary_context = device
- .primary_context
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- Ok(f(&mut maybe_primary_context))
+ let raw_ptr = &device.primary_context as *const _;
+ let context = device.primary_context.as_ref_unchecked();
+ match context.variant {
+ ContextVariant::Primary(ref mutex_over_primary_ctx) => {
+ let mut primary_ctx = mutex_over_primary_ctx.lock().map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
+ Ok(fn_(&mut primary_ctx, raw_ptr))
+ },
+ ContextVariant::NonPrimary(..) => Err(CUresult::CUDA_ERROR_UNKNOWN)
+ }
}
pub(crate) unsafe fn get_name(name: *mut i8, len: i32, device: i32) -> hipError_t {
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 88a95c4..34566af 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -148,6 +148,10 @@ impl<T: ZludaObject> LiveCheck<T> {
outer_ptr as *mut Self
}
+ pub unsafe fn as_ref_unchecked(&self) -> & T {
+ &self.data
+ }
+
pub unsafe fn as_mut_unchecked(&mut self) -> &mut T {
&mut self.data
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index 6a6911a..8a49d43 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -31,13 +31,11 @@ impl ZludaObject for ModuleData {
let deregistration_err = if !by_owner {
if let Some(ctx) = self.owner {
let ctx = unsafe { LiveCheck::as_result(ctx.as_ptr())? };
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable
- .modules
- .remove(&unsafe { LiveCheck::from_raw(self) });
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable
+ .modules
+ .remove(&unsafe { LiveCheck::from_raw(self) });
+ })?;
}
Ok(())
} else {
@@ -104,11 +102,9 @@ pub(crate) unsafe fn load_impl(
isa,
input,
)?);
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable.modules.insert(module);
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable.modules.insert(module);
+ })?;
*output = module;
Ok(())
})?
diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs
index fb53510..71ed20b 100644
--- a/zluda/src/impl/stream.rs
+++ b/zluda/src/impl/stream.rs
@@ -21,13 +21,11 @@ impl ZludaObject for StreamData {
if !by_owner {
let ctx = unsafe { LiveCheck::as_result(self.ctx)? };
{
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable
- .streams
- .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable
+ .streams
+ .remove(&unsafe { LiveCheck::from_raw(&mut *self) });
+ })?;
}
}
hip_call_cuda!(hipStreamDestroy(self.base));
@@ -59,11 +57,9 @@ pub(crate) unsafe fn create_with_priority(
ctx: ptr::null_mut(),
})));
let ctx = context::with_current(|ctx| {
- let mut ctx_mutable = ctx
- .mutable
- .lock()
- .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
- ctx_mutable.streams.insert(stream);
+ ctx.with_inner_mut(|ctx_mutable| {
+ ctx_mutable.streams.insert(stream);
+ })?;
Ok(LiveCheck::from_raw(ctx as *const _ as _))
})??;
(*stream).as_mut_unchecked().ctx = ctx;