diff options
author | Andrzej Janik <[email protected]> | 2021-07-04 12:54:27 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2021-07-04 12:54:27 +0200 |
commit | b460e359ae5eb21ca75dec655daa671e559f6a45 (patch) | |
tree | b362e0efdeff572f75ab3bda02735c91c0092681 | |
parent | ad2059872a43315ecb5361db1ab2c74d55363fae (diff) | |
download | ZLUDA-b460e359ae5eb21ca75dec655daa671e559f6a45.tar.gz ZLUDA-b460e359ae5eb21ca75dec655daa671e559f6a45.zip |
First attempt at async host side
-rw-r--r-- | level_zero/src/ze.rs | 14 | ||||
-rw-r--r-- | zluda/src/cuda.rs | 105 | ||||
-rw-r--r-- | zluda/src/impl/context.rs | 11 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 106 | ||||
-rw-r--r-- | zluda/src/impl/memory.rs | 11 | ||||
-rw-r--r-- | zluda/src/impl/mod.rs | 30 | ||||
-rw-r--r-- | zluda/src/impl/stream.rs | 56 |
7 files changed, 294 insertions, 39 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 703f2ce..16a98a0 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -325,6 +325,11 @@ impl<'a> CommandQueue<'a> { ));
Ok(())
}
+
+ pub fn synchronize(&self, timeout_ns: u64) -> Result<()> { // Calls zeCommandQueueSynchronize on this queue's raw handle, passing `timeout_ns` through unchanged.
+ check!(sys::zeCommandQueueSynchronize(self.as_ffi(), timeout_ns)); // NOTE(review): `check!` is defined outside this hunk; presumably it returns early with Err on a non-success result -- confirm.
+ Ok(())
+ }
}
impl<'a> Drop for CommandQueue<'a> {
@@ -1097,6 +1102,15 @@ impl<'a> Event<'a> { Ok(unsafe { Self::from_ffi(result) })
}
+ pub fn is_ready(&self) -> Result<bool> { // Queries this event with zeEventQueryStatus and maps the result code to a bool.
+ let status = unsafe { sys::zeEventQueryStatus(self.as_ffi()) };
+ match status {
+ sys::ze_result_t::ZE_RESULT_SUCCESS => Ok(true), // success is reported as "ready"
+ sys::ze_result_t::ZE_RESULT_NOT_READY => Ok(false), // "not ready" is an expected answer here, not an error
+ err => Err(err), // every other result code is handed back to the caller as the error value
+ }
+ }
+
unsafe fn with_raw_slice<'x, T>(
events: &[&Event<'x>],
f: impl FnOnce(u32, *mut sys::ze_event_handle_t) -> T,
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs index 9e7cbff..1bf10fd 100644 --- a/zluda/src/cuda.rs +++ b/zluda/src/cuda.rs @@ -2186,7 +2186,7 @@ pub extern "system" fn cuGetErrorString( error: CUresult, pStr: *mut *const ::std::os::raw::c_char, ) -> CUresult { - r#impl::get_error_string(error, pStr).encuda() + r#impl::get_error_string(error, pStr).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2209,7 +2209,10 @@ pub extern "system" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_ } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult { +pub extern "system" fn cuDeviceGet( + device: *mut CUdevice, + ordinal: ::std::os::raw::c_int, +) -> CUresult { r#impl::device::get(device.decuda(), ordinal).encuda() } @@ -2374,7 +2377,7 @@ pub extern "system" fn cuCtxGetFlags(flags: *mut ::std::os::raw::c_uint) -> CUre #[cfg_attr(not(test), no_mangle)] pub extern "system" fn cuCtxSynchronize() -> CUresult { - r#impl::context::synchronize() + r#impl::context::synchronize().encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2429,7 +2432,10 @@ pub extern "system" fn cuCtxResetPersistingL2Cache() -> CUresult { } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult { +pub extern "system" fn cuCtxAttach( + pctx: *mut CUcontext, + flags: ::std::os::raw::c_uint, +) -> CUresult { r#impl::context::attach(pctx.decuda(), flags).encuda() } @@ -2667,7 +2673,10 @@ pub extern "system" fn cuDeviceGetPCIBusId( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuIpcGetEventHandle(pHandle: *mut CUipcEventHandle, event: CUevent) -> CUresult { +pub extern "system" fn cuIpcGetEventHandle( + pHandle: *mut CUipcEventHandle, + event: CUevent, +) -> CUresult { r#impl::unimplemented() } @@ -2680,7 +2689,10 @@ pub extern "system" fn cuIpcOpenEventHandle( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn 
cuIpcGetMemHandle(pHandle: *mut CUipcMemHandle, dptr: CUdeviceptr) -> CUresult { +pub extern "system" fn cuIpcGetMemHandle( + pHandle: *mut CUipcMemHandle, + dptr: CUdeviceptr, +) -> CUresult { r#impl::unimplemented() } @@ -2930,12 +2942,18 @@ pub extern "system" fn cuMemcpyAtoHAsync_v2( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuMemcpy2DAsync_v2(pCopy: *const CUDA_MEMCPY2D, hStream: CUstream) -> CUresult { +pub extern "system" fn cuMemcpy2DAsync_v2( + pCopy: *const CUDA_MEMCPY2D, + hStream: CUstream, +) -> CUresult { r#impl::unimplemented() } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuMemcpy3DAsync_v2(pCopy: *const CUDA_MEMCPY3D, hStream: CUstream) -> CUresult { +pub extern "system" fn cuMemcpy3DAsync_v2( + pCopy: *const CUDA_MEMCPY3D, + hStream: CUstream, +) -> CUresult { r#impl::unimplemented() } @@ -3406,7 +3424,9 @@ pub extern "system" fn cuStreamBeginCapture_v2( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuThreadExchangeStreamCaptureMode(mode: *mut CUstreamCaptureMode) -> CUresult { +pub extern "system" fn cuThreadExchangeStreamCaptureMode( + mode: *mut CUstreamCaptureMode, +) -> CUresult { r#impl::unimplemented() } @@ -3449,7 +3469,7 @@ pub extern "system" fn cuStreamQuery(hStream: CUstream) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "system" fn cuStreamSynchronize(hStream: CUstream) -> CUresult { - CUresult::CUDA_SUCCESS + r#impl::stream::synchronize(hStream.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3481,7 +3501,10 @@ pub extern "system" fn cuStreamSetAttribute( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuEventCreate(phEvent: *mut CUevent, Flags: ::std::os::raw::c_uint) -> CUresult { +pub extern "system" fn cuEventCreate( + phEvent: *mut CUevent, + Flags: ::std::os::raw::c_uint, +) -> CUresult { r#impl::unimplemented() } @@ -3652,7 +3675,10 @@ pub extern "system" fn cuFuncSetCacheConfig(hfunc: CUfunction, config: CUfunc_ca } #[cfg_attr(not(test), 
no_mangle)] -pub extern "system" fn cuFuncSetSharedMemConfig(hfunc: CUfunction, config: CUsharedconfig) -> CUresult { +pub extern "system" fn cuFuncSetSharedMemConfig( + hfunc: CUfunction, + config: CUsharedconfig, +) -> CUresult { r#impl::unimplemented() } @@ -3770,7 +3796,10 @@ pub extern "system" fn cuFuncSetSharedSize( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuParamSetSize(hfunc: CUfunction, numbytes: ::std::os::raw::c_uint) -> CUresult { +pub extern "system" fn cuParamSetSize( + hfunc: CUfunction, + numbytes: ::std::os::raw::c_uint, +) -> CUresult { r#impl::unimplemented() } @@ -3836,7 +3865,10 @@ pub extern "system" fn cuParamSetTexRef( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuGraphCreate(phGraph: *mut CUgraph, flags: ::std::os::raw::c_uint) -> CUresult { +pub extern "system" fn cuGraphCreate( + phGraph: *mut CUgraph, + flags: ::std::os::raw::c_uint, +) -> CUresult { r#impl::unimplemented() } @@ -3980,7 +4012,10 @@ pub extern "system" fn cuGraphAddEmptyNode( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuGraphClone(phGraphClone: *mut CUgraph, originalGraph: CUgraph) -> CUresult { +pub extern "system" fn cuGraphClone( + phGraphClone: *mut CUgraph, + originalGraph: CUgraph, +) -> CUresult { r#impl::unimplemented() } @@ -3994,7 +4029,10 @@ pub extern "system" fn cuGraphNodeFindInClone( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuGraphNodeGetType(hNode: CUgraphNode, type_: *mut CUgraphNodeType) -> CUresult { +pub extern "system" fn cuGraphNodeGetType( + hNode: CUgraphNode, + type_: *mut CUgraphNodeType, +) -> CUresult { r#impl::unimplemented() } @@ -4144,7 +4182,10 @@ pub extern "system" fn cuGraphExecUpdate( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuGraphKernelNodeCopyAttributes(dst: CUgraphNode, src: CUgraphNode) -> CUresult { +pub extern "system" fn cuGraphKernelNodeCopyAttributes( + dst: CUgraphNode, + src: CUgraphNode, +) -> CUresult { r#impl::unimplemented() } @@ 
-4284,7 +4325,10 @@ pub extern "system" fn cuTexRefSetFilterMode(hTexRef: CUtexref, fm: CUfilter_mod } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefSetMipmapFilterMode(hTexRef: CUtexref, fm: CUfilter_mode) -> CUresult { +pub extern "system" fn cuTexRefSetMipmapFilterMode( + hTexRef: CUtexref, + fm: CUfilter_mode, +) -> CUresult { r#impl::unimplemented() } @@ -4311,17 +4355,26 @@ pub extern "system" fn cuTexRefSetMaxAnisotropy( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefSetBorderColor(hTexRef: CUtexref, pBorderColor: *mut f32) -> CUresult { +pub extern "system" fn cuTexRefSetBorderColor( + hTexRef: CUtexref, + pBorderColor: *mut f32, +) -> CUresult { r#impl::unimplemented() } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefSetFlags(hTexRef: CUtexref, Flags: ::std::os::raw::c_uint) -> CUresult { +pub extern "system" fn cuTexRefSetFlags( + hTexRef: CUtexref, + Flags: ::std::os::raw::c_uint, +) -> CUresult { r#impl::unimplemented() } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefGetAddress_v2(pdptr: *mut CUdeviceptr, hTexRef: CUtexref) -> CUresult { +pub extern "system" fn cuTexRefGetAddress_v2( + pdptr: *mut CUdeviceptr, + hTexRef: CUtexref, +) -> CUresult { r#impl::unimplemented() } @@ -4348,7 +4401,10 @@ pub extern "system" fn cuTexRefGetAddressMode( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefGetFilterMode(pfm: *mut CUfilter_mode, hTexRef: CUtexref) -> CUresult { +pub extern "system" fn cuTexRefGetFilterMode( + pfm: *mut CUfilter_mode, + hTexRef: CUtexref, +) -> CUresult { r#impl::unimplemented() } @@ -4392,7 +4448,10 @@ pub extern "system" fn cuTexRefGetMaxAnisotropy( } #[cfg_attr(not(test), no_mangle)] -pub extern "system" fn cuTexRefGetBorderColor(pBorderColor: *mut f32, hTexRef: CUtexref) -> CUresult { +pub extern "system" fn cuTexRefGetBorderColor( + pBorderColor: *mut f32, + hTexRef: CUtexref, +) -> CUresult { r#impl::unimplemented() } diff --git 
a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 5ef427e..9ea0874 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -257,9 +257,14 @@ pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { })? } -pub(crate) fn synchronize() -> CUresult { - // TODO: change the implementation once we do async stream operations - CUresult::CUDA_SUCCESS +/* Synchronizes the current context's default stream and then every stream in `ctx.streams`, stopping at the first error. */ pub(crate) fn synchronize() -> Result<(), CUresult> { + GlobalState::lock_current_context(|ctx| { + ctx.default_stream.synchronize()?; + for stream in ctx.streams.iter().copied() { + unsafe { &mut *stream }.synchronize()?; + } + Ok(()) + })? } #[cfg(test)] diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 63bf39f..0594252 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -1,4 +1,4 @@ -use super::{context, CUresult, GlobalState}; +use super::{context, transmute_lifetime, transmute_lifetime_mut, CUresult, GlobalState}; use crate::cuda; use cuda::{CUdevice_attribute, CUuuid_st}; use std::{ @@ -21,6 +21,7 @@ pub struct Device { pub default_queue: l0::CommandQueue<'static>, pub l0_context: l0::Context, pub primary_context: context::Context, + pub event_pool: DynamicEventPool, properties: Option<Box<l0::sys::ze_device_properties_t>>, image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>, memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>, @@ -42,12 +43,14 @@ impl Device { true, ptr::null_mut(), )?); + let event_pool = DynamicEventPool::new(l0_dev, transmute_lifetime(&ctx))?; Ok(Self { index: Index(idx as c_int), base: l0_dev, default_queue: queue, l0_context: ctx, primary_context: primary_context, + event_pool, properties: None, image_properties: None, memory_properties: None, @@ -395,8 +398,103 @@ pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult { CUresult::CUDA_SUCCESS } +/* Growable list of event-pool entries for one device. `count` is incremented by `get` and decremented by `mark_as_free`. */ pub struct DynamicEventPool { + count: usize, + events: Vec<DynamicEventPoolEntry>, +} + +impl DynamicEventPool { + /* Starts with a single pool entry created for `dev` in `ctx`. */ fn new(dev:
l0::Device, ctx: &'static l0::Context) -> l0::Result<Self> { + Ok(DynamicEventPool { + count: 0, + events: vec![DynamicEventPoolEntry::new(dev, ctx)?], + }) + } + + /* Hands out an unused event together with a marker identifying it: the high 32 bits are the index of the pool entry in `events`, the low 32 bits are the event index inside that entry. When every existing entry is full a new entry is appended and used. NOTE(review): `count` is bumped before the fallible calls, so an error return leaves it one too high -- confirm whether that matters. */ pub fn get( + &'static mut self, + dev: l0::Device, + ctx: &'static l0::Context, + ) -> l0::Result<(l0::Event<'static>, u64)> { + self.count += 1; + let events = unsafe { transmute_lifetime_mut(&mut self.events) }; + let (global_idx, (ev, local_idx)) = { + for (idx, entry) in self.events.iter_mut().enumerate() { + if let Some((ev, local_idx)) = entry.get()? { + /* widen to u64 before shifting so the shift is never performed in a 32-bit usize; this matches the marker built on the fallback path below */ let marker = ((idx as u64) << 32) | local_idx as u64; + return Ok((ev, marker)); + } + } + events.push(DynamicEventPoolEntry::new(dev, ctx)?); + let global_idx = (events.len() - 1) as u64; + (global_idx, events.last_mut().unwrap().get()?.unwrap()) + }; + let marker = (global_idx << 32) | local_idx as u64; + Ok((ev, marker)) + } + + /* Returns the event identified by `marker` (as produced by `get`) to its pool entry; panics if the entry index in the marker is out of range. */ pub fn mark_as_free(&mut self, marker: u64) { + let global_idx = (marker >> 32) as u32; + self.events[global_idx as usize].mark_as_free(marker as u32); + self.count -= 1; + // TODO: clean up empty entries + } +} + +const DYNAMIC_EVENT_POOL_ENTRY_SIZE: usize = 448; +const DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE: usize = + DYNAMIC_EVENT_POOL_ENTRY_SIZE / (mem::size_of::<u64>() * 8); +#[repr(C)] +#[repr(align(64))] +/* One event pool of DYNAMIC_EVENT_POOL_ENTRY_SIZE events plus a bitmap with one bit per event (set = handed out). */ struct DynamicEventPoolEntry { + event_pool: l0::EventPool<'static>, + bit_map: [u64; DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE], +} + +impl DynamicEventPoolEntry { + fn new(dev: l0::Device, ctx: &'static l0::Context) -> l0::Result<Self> { + Ok(DynamicEventPoolEntry { + event_pool: l0::EventPool::new( + ctx, + DYNAMIC_EVENT_POOL_ENTRY_SIZE as u32, + Some(&[dev]), + )?, + bit_map: [0; DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE], + }) + } + + /* Claims the lowest clear bit in the bitmap and creates the event at that index; returns Ok(None) when every bit is set. */ fn get(&'static mut self) -> l0::Result<Option<(l0::Event<'static>, u32)>> { + for (idx, value) in self.bit_map.iter_mut().enumerate() { + let shift = first_index_of_zero_u64(*value); + if shift == 64 { + continue; + } + *value = *value | (1u64 << shift); + let entry_index = (idx as u32 *
64u32) + shift; + let event = l0::Event::new(&self.event_pool, entry_index)?; + return Ok(Some((event, entry_index))); + } + Ok(None) + } + + /* Clears the bitmap bit for event `idx`, making it available to `get` again. */ fn mark_as_free(&mut self, idx: u32) { + let value = &mut self.bit_map[idx as usize / 64]; + let shift = idx % 64; + *value = *value & !(1u64 << shift); + } +} + +/* Index of the lowest zero bit of `x`; 64 when `x` has no zero bit. */ fn first_index_of_zero_u64(x: u64) -> u32 { + let x = !x; + (x & x.wrapping_neg()).trailing_zeros() +} + #[cfg(test)] mod test { + use std::mem; + + use super::DynamicEventPoolEntry; + use super::super::test::CudaDriverFns; use super::super::CUresult; @@ -413,4 +511,10 @@ mod test { assert_eq!(flags, 0); assert_eq!(active, 0); } + + #[test] + pub fn dynamic_event_pool_page_is_64b() { + assert_eq!(mem::size_of::<DynamicEventPoolEntry>(), 64); + assert_eq!(mem::align_of::<DynamicEventPoolEntry>(), 64); + } } diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 238d68e..81b4f31 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -11,13 +11,10 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> }
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> { // Appends a raw memory copy of `bytesize` bytes from `src` to `dst` on the stream named by stream::CU_STREAM_LEGACY.
- GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let cmd_list = stream.command_list()?;
- unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, None, &mut [])? };
- cmd_list.close()?;
- stream.queue.execute_and_synchronize(cmd_list)?;
- Ok::<_, CUresult>(())
- })?
+ GlobalState::lock_enqueue(stream::CU_STREAM_LEGACY, |cmd_list, signal, wait| { // the closure receives the command list plus the two event arguments that are forwarded to the copy below
+ unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, Some(signal), wait)? }; // NOTE(review): the removed version closed the list and called execute_and_synchronize; this closure only appends -- confirm where the list is submitted
+ Ok::<_, l0::sys::ze_result_t>(())
+ })
}
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 55c047f..48b5fb5 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -273,6 +273,32 @@ impl GlobalState { } } + /* Locks `stream`, releases the pool markers of its already-finished events, then calls `f` with a command list, a fresh event taken from the device's event pool, and a wait list holding the stream's most recent pending event (empty when there is none). On success the fresh event is appended to the stream's pending events. */ fn lock_enqueue( + stream: *mut stream::Stream, + f: impl FnOnce( + &mut l0::CommandList, + &l0::Event<'static>, + &[&l0::Event<'static>], + ) -> l0::Result<()>, + ) -> Result<(), CUresult> { + Self::lock_stream(stream, |stream_data| { + let l0_dev = unsafe { (*(*stream_data.context).device).base }; + let l0_ctx = unsafe { &mut (*(*stream_data.context).device).l0_context }; + let event_pool = unsafe { &mut (*(*stream_data.context).device).event_pool }; + let mut cmd_list = unsafe { mem::transmute(stream_data.command_list()?) }; + stream_data + .process_finished_events(&mut |(_, marker)| event_pool.mark_as_free(marker))?; + let prev_event = stream_data.get_last_event(); + let prev_event_array = prev_event.map(|e| [e]); + let empty = []; + let prev_event_slice = prev_event_array.as_ref().map_or(&empty[..], |arr| &arr[..]); + let (new_event, new_marker) = event_pool.get(l0_dev, l0_ctx)?; + f(&mut cmd_list, &new_event, prev_event_slice)?; /* NOTE(review): no close or submission of `cmd_list` is visible in this function, and the list goes out of scope at the end of the closure -- confirm where it is executed */ + stream_data.push_event((new_event, new_marker)); + Ok(()) + })?
+ } + fn lock_function<T>( func: *mut function::Function, f: impl FnOnce(&mut function::FunctionData) -> T, @@ -421,6 +447,10 @@ pub(crate) fn get_error_string(error: CUresult, str: *mut *const i8) -> CUresult } } +/* Reinterprets a shared reference with an arbitrary caller-chosen lifetime; the caller must guarantee the referent really lives that long. */ unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { + mem::transmute(t) +} + unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T { mem::transmute(t) } diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs index 0fafe92..11f1869 100644 --- a/zluda/src/impl/stream.rs +++ b/zluda/src/impl/stream.rs @@ -2,7 +2,7 @@ use super::{ context::{Context, ContextData}, CUresult, GlobalState, }; -use std::{mem, ptr}; +use std::{collections::VecDeque, mem, ptr}; use super::{HasLivenessCookie, LiveCheck}; @@ -34,21 +34,27 @@ impl HasLivenessCookie for StreamData { pub struct StreamData { pub context: *mut ContextData, pub queue: l0::CommandQueue<'static>, + /* pending events, oldest at the front, each paired with the marker it was handed out with */ pub prev_events: VecDeque<(l0::Event<'static>, u64)>, } impl StreamData { - pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> { + pub fn new_unitialized( + ctx: &'static l0::Context, + device: l0::Device, + ) -> Result<Self, CUresult> { Ok(StreamData { context: ptr::null_mut(), - queue: l0::CommandQueue::new(ctx, dev)?, + queue: l0::CommandQueue::new(ctx, device)?, + prev_events: VecDeque::new(), }) } pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; - let l0_dev = unsafe { &*ctx.device }.base; + let device = unsafe { &*ctx.device }.base; Ok(StreamData { context: ctx as *mut _, - queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, + queue: l0::CommandQueue::new(l0_ctx, device)?, + prev_events: VecDeque::new(), }) } @@ -57,6 +63,39 @@ impl StreamData { let dev = unsafe { &mut *ctx.device }; l0::CommandList::new(&mut dev.l0_context, dev.base) } + + /* Pops events from the front of `prev_events` for as long as `is_ready` reports true, passing each popped (event, marker) pair to `f`; stops at the first event that is not ready or when the queue is empty, and propagates any `is_ready` error. */ pub fn process_finished_events( + &mut self, + f: &mut impl FnMut((l0::Event<'static>, u64)), + ) ->
l0::Result<()> { + loop { + match self.prev_events.get(0) { + None => return Ok(()), + Some((ev, _)) => { + if ev.is_ready()? { + f(self.prev_events.pop_front().unwrap()); + } else { + return Ok(()); + } + } + } + } + } + + /* Most recently pushed pending event, if any. */ pub fn get_last_event(&self) -> Option<&l0::Event<'static>> { + self.prev_events.iter().next_back().map(|(ev, _)| ev) + } + + /* Appends an (event, marker) pair to the back of the pending queue. */ pub fn push_event(&mut self, ev: (l0::Event<'static>, u64)) { + self.prev_events.push_back(ev); + } + + /* Synchronizes the stream's queue with a timeout of u64::MAX, then releases the markers of all finished events back to the device's event pool. NOTE(review): this dereferences `self.context`, which `new_unitialized` sets to null -- confirm it is always assigned before this is called. */ pub fn synchronize(&mut self) -> l0::Result<()> { + self.queue.synchronize(u64::MAX)?; + let event_pool = unsafe { &mut (*(*self.context).device).event_pool }; + self.process_finished_events(&mut |(_, marker)| event_pool.mark_as_free(marker))?; + Ok(()) + } } impl Drop for StreamData { @@ -102,6 +141,13 @@ pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> { GlobalState::lock(|_| Stream::destroy_impl(pstream))? } +/* Synchronizes the given stream under the stream lock; a null handle is rejected with CUDA_ERROR_INVALID_VALUE. NOTE(review): confirm that rejecting null is intended, since callers may pass a null handle to mean a default stream. */ pub(crate) fn synchronize(pstream: *mut Stream) -> Result<(), CUresult> { + if pstream == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock_stream(pstream, |stream_data| Ok(stream_data.synchronize()?))? +} + #[cfg(test)] mod test { use crate::cuda::CUstream; |