diff options
Diffstat (limited to 'notcuda/src/impl/stream.rs')
-rw-r--r-- | notcuda/src/impl/stream.rs | 203 |
1 files changed, 188 insertions, 15 deletions
diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs index 1844677..e212dfc 100644 --- a/notcuda/src/impl/stream.rs +++ b/notcuda/src/impl/stream.rs @@ -1,36 +1,114 @@ -use std::cell::RefCell; +use super::{ + context::{Context, ContextData}, + CUresult, GlobalState, +}; +use std::{mem, ptr}; -use device::Device; +use super::{HasLivenessCookie, LiveCheck}; -use super::device; +pub type Stream = LiveCheck<StreamData>; -pub struct Stream { - dev: *mut Device, +pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _; +pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _; + +impl HasLivenessCookie for StreamData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x512097354de18d35; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x77d5cc0b; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + if self.context != ptr::null_mut() { + let context = unsafe { &mut *self.context }; + if !context.streams.remove(&(self as *mut _)) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } + } + Ok(()) + } +} + +pub struct StreamData { + pub context: *mut ContextData, + pub queue: l0::CommandQueue, } -pub struct DefaultStream { - streams: Vec<Option<Stream>>, +impl StreamData { + pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> { + Ok(StreamData { + context: ptr::null_mut(), + queue: l0::CommandQueue::new(ctx, dev)?, + }) + } + pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { + let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; + let l0_dev = &unsafe { &*ctx.device }.base; + Ok(StreamData { + context: ctx as *mut _, + queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, + }) + } + + pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> { + let ctx = unsafe { &mut *self.context }; + let dev = unsafe { &mut *ctx.device }; + l0::CommandList::new(&mut dev.l0_context, &dev.base) + } } -impl DefaultStream { - fn new() -> Self { - DefaultStream { - streams: Vec::new(), +impl Drop for StreamData { + fn drop(&mut self) { + if self.context == ptr::null_mut() { + return; } + unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) }; } } -thread_local! { - pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new()); +pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?; + if ctx_ptr == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED); + } + unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) }; + Ok(()) +} + +pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> { + let stream_ptr = GlobalState::lock_current_context(|ctx| { + let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?)); + let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _; + if !ctx.streams.insert(stream_ptr) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } + mem::forget(stream_box); + Ok::<_, CUresult>(stream_ptr) + })??; + unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) }; + Ok(()) +} + +pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> { + if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD + { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Stream::destroy_impl(pstream))? } #[cfg(test)] -mod tests { +mod test { use crate::cuda::CUstream; use super::super::test::CudaDriverFns; use super::super::CUresult; - use std::ptr; + use std::{ptr, thread}; const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; @@ -65,5 +143,100 @@ mod tests { CUresult::CUDA_SUCCESS ); assert_eq!(ctx2, stream_ctx2); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_context_destroyed); + + fn stream_context_destroyed<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + let mut stream_ctx2 = ptr::null_mut(); + // When a context gets destroyed, its streams are also destroyed + let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2); + assert!( + cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE + || cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT + || cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED + ); + assert_eq!( + T::cuStreamDestroy_v2(stream), + CUresult::CUDA_ERROR_INVALID_HANDLE + ); + // Check if creating another context is possible + let mut ctx2 = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_moves_context_to_another_thread); + + fn stream_moves_context_to_another_thread<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + let stream_ptr = stream as usize; + let stream_ctx_on_thread = thread::spawn(move || { + let mut stream_ctx2 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2), + CUresult::CUDA_SUCCESS + ); + stream_ctx2 as usize + }) + .join() + .unwrap(); + assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _); + // Cleanup + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(can_destroy_stream); + + fn can_destroy_stream<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(cant_destroy_default_stream); + + fn cant_destroy_default_stream<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + assert_ne!( + T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _), + CUresult::CUDA_SUCCESS + ); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); } } |