summaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/context.rs
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda/src/impl/context.rs')
-rw-r--r--notcuda/src/impl/context.rs191
1 files changed, 103 insertions, 88 deletions
diff --git a/notcuda/src/impl/context.rs b/notcuda/src/impl/context.rs
index 91d4460..9689ecf 100644
--- a/notcuda/src/impl/context.rs
+++ b/notcuda/src/impl/context.rs
@@ -1,18 +1,15 @@
-use super::CUresult;
-use super::{device, HasLivenessCookie, LiveCheck};
+use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck};
+use super::{CUresult, GlobalState};
use crate::{cuda::CUcontext, cuda_impl};
use l0::sys::ze_result_t;
-use std::mem::{self, ManuallyDrop};
+use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32};
use std::{
- cell::RefCell,
- num::NonZeroU32,
- os::raw::c_uint,
- ptr,
- sync::{atomic::AtomicU32, Mutex},
+ collections::HashSet,
+ mem::{self},
};
thread_local! {
- pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new());
+ pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new());
}
pub type Context = LiveCheck<ContextData>;
@@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData {
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x0b643ffb;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ for stream in self.streams.iter() {
+ let stream = unsafe { &mut **stream };
+ stream.context = ptr::null_mut();
+ Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?;
+ }
+ Ok(())
+ }
}
enum ContextRefCount {
@@ -67,26 +75,16 @@ impl ContextRefCount {
}
}
}
-
- fn is_primary(&self) -> bool {
- match self {
- ContextRefCount::Primary => true,
- ContextRefCount::NonPrimary(_) => false,
- }
- }
}
pub struct ContextData {
pub flags: AtomicU32,
- pub device_index: device::Index,
// This pointer is null only for a moment when constructing primary context
- pub device: *const Mutex<device::Device>,
- // The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState
- pub mutable: Mutex<ContextDataMutable>,
-}
-
-pub struct ContextDataMutable {
+ pub device: *mut device::Device,
ref_count: ContextRefCount,
+ pub default_stream: StreamData,
+ pub streams: HashSet<*mut StreamData>,
+ // All the fields below are here to support internal CUDA driver API
pub cuda_manager: *mut cuda_impl::rt::ContextStateManager,
pub cuda_state: *mut cuda_impl::rt::ContextState,
pub cuda_dtor_cb: Option<
@@ -100,63 +98,75 @@ pub struct ContextDataMutable {
impl ContextData {
pub fn new(
+ l0_ctx: &mut l0::Context,
+ l0_dev: &l0::Device,
flags: c_uint,
is_primary: bool,
- dev_index: device::Index,
- dev: *const Mutex<device::Device>,
- ) -> Self {
- ContextData {
+ dev: *mut device::Device,
+ ) -> Result<Self, CUresult> {
+ let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?;
+ Ok(ContextData {
flags: AtomicU32::new(flags),
- device_index: dev_index,
device: dev,
- mutable: Mutex::new(ContextDataMutable {
- ref_count: ContextRefCount::new(is_primary),
- cuda_manager: ptr::null_mut(),
- cuda_state: ptr::null_mut(),
- cuda_dtor_cb: None,
- }),
- }
+ ref_count: ContextRefCount::new(is_primary),
+ default_stream,
+ streams: HashSet::new(),
+ cuda_manager: ptr::null_mut(),
+ cuda_state: ptr::null_mut(),
+ cuda_dtor_cb: None,
+ })
}
}
-pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult {
+impl Context {
+ pub fn late_init(&mut self) {
+ let ctx_data = self.as_option_mut().unwrap();
+ ctx_data.default_stream.context = ctx_data as *mut _;
+ }
+}
+
+pub fn create_v2(
+ pctx: *mut *mut Context,
+ flags: u32,
+ dev_idx: device::Index,
+) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let dev = device::get_device_ref(dev_idx);
- let dev = match dev {
- Ok(d) => d,
- Err(e) => return e,
- };
- let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev)));
- let ctx_ref = ctx.as_mut() as *mut Context;
+ let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
+ let dev_ptr = dev as *mut _;
+ let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
+ &mut dev.l0_context,
+ &dev.base,
+ flags,
+ false,
+ dev_ptr as *mut _,
+ )?));
+ ctx_box.late_init();
+ Ok::<_, CUresult>(ctx_box)
+ })??;
+ let ctx_ref = ctx_box.as_mut() as *mut Context;
unsafe { *pctx = ctx_ref };
- mem::forget(ctx);
+ mem::forget(ctx_box);
CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref));
- CUresult::CUDA_SUCCESS
+ Ok(())
}
-pub fn destroy_v2(ctx: *mut Context) -> CUresult {
+pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
let should_pop = match stack.last() {
- Some(active_ctx) => *active_ctx == (ctx as *const _),
+ Some(active_ctx) => *active_ctx == (ctx as *mut _),
None => false,
};
if should_pop {
stack.pop();
}
});
- let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) });
- if !ctx_box.try_drop() {
- CUresult::CUDA_ERROR_INVALID_CONTEXT
- } else {
- unsafe { ManuallyDrop::drop(&mut ctx_box) };
- CUresult::CUDA_SUCCESS
- }
+ GlobalState::lock(|_| Context::destroy_impl(ctx))?
}
pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
@@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
CUresult::CUDA_SUCCESS
}
-pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> {
- CONTEXT_STACK.with(|stack| {
- stack
- .borrow()
- .last()
- .and_then(|c| unsafe { &**c }.as_ref())
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .map(f)
- })
-}
-
pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> {
if pctx == ptr::null_mut() {
return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT);
@@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult {
}
}
-pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult {
- let _ctx = match unsafe { ctx.as_mut() } {
- None => return CUresult::CUDA_ERROR_INVALID_VALUE,
- Some(ctx) => match ctx.as_mut() {
- None => return CUresult::CUDA_ERROR_INVALID_CONTEXT,
- Some(ctx) => ctx,
- },
- };
+pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
+ if ctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock(|_| {
+ unsafe { &*ctx }.as_result()?;
+ Ok::<_, CUresult>(())
+ })??;
//TODO: query device for properties roughly matching CUDA API version
unsafe { *version = 1100 };
- CUresult::CUDA_SUCCESS
+ Ok(())
}
-pub fn get_device(dev: *mut device::Index) -> CUresult {
- let dev_idx = with_current(|ctx| ctx.device_index);
- match dev_idx {
- Ok(idx) => {
- unsafe { *dev = idx }
- CUresult::CUDA_SUCCESS
- }
- Err(err) => err,
+pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> {
+ let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?;
+ unsafe { *dev = dev_idx };
+ Ok(())
+}
+
+pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> {
+ if pctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
+ let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
+ let ctx = unchecked_ctx.as_result_mut()?;
+ ctx.ref_count.incr()?;
+ Ok::<_, CUresult>(unchecked_ctx as *mut _)
+ })??;
+ unsafe { *pctx = ctx };
+ Ok(())
}
-#[cfg(test)]
-pub fn is_context_stack_empty() -> bool {
- CONTEXT_STACK.with(|stack| stack.borrow().is_empty())
+pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
+ if pctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
+ let ctx = unchecked_ctx.as_result_mut()?;
+ if ctx.ref_count.decr() {
+ Context::destroy_impl(unchecked_ctx)?;
+ }
+ Ok::<_, CUresult>(())
+ })?
}
#[cfg(test)]
-mod tests {
+mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::{ffi::c_void, ptr};