aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-07-04 12:54:27 +0200
committerAndrzej Janik <[email protected]>2021-07-04 12:54:27 +0200
commitb460e359ae5eb21ca75dec655daa671e559f6a45 (patch)
treeb362e0efdeff572f75ab3bda02735c91c0092681
parentad2059872a43315ecb5361db1ab2c74d55363fae (diff)
downloadZLUDA-b460e359ae5eb21ca75dec655daa671e559f6a45.tar.gz
ZLUDA-b460e359ae5eb21ca75dec655daa671e559f6a45.zip
First attempt at async host side
-rw-r--r--level_zero/src/ze.rs14
-rw-r--r--zluda/src/cuda.rs105
-rw-r--r--zluda/src/impl/context.rs11
-rw-r--r--zluda/src/impl/device.rs106
-rw-r--r--zluda/src/impl/memory.rs11
-rw-r--r--zluda/src/impl/mod.rs30
-rw-r--r--zluda/src/impl/stream.rs56
7 files changed, 294 insertions, 39 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index 703f2ce..16a98a0 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -325,6 +325,11 @@ impl<'a> CommandQueue<'a> {
));
Ok(())
}
+
+ pub fn synchronize(&self, timeout_ns: u64) -> Result<()> {
+ check!(sys::zeCommandQueueSynchronize(self.as_ffi(), timeout_ns));
+ Ok(())
+ }
}
impl<'a> Drop for CommandQueue<'a> {
@@ -1097,6 +1102,15 @@ impl<'a> Event<'a> {
Ok(unsafe { Self::from_ffi(result) })
}
+ pub fn is_ready(&self) -> Result<bool> {
+ let status = unsafe { sys::zeEventQueryStatus(self.as_ffi()) };
+ match status {
+ sys::ze_result_t::ZE_RESULT_SUCCESS => Ok(true),
+ sys::ze_result_t::ZE_RESULT_NOT_READY => Ok(false),
+ err => Err(err),
+ }
+ }
+
unsafe fn with_raw_slice<'x, T>(
events: &[&Event<'x>],
f: impl FnOnce(u32, *mut sys::ze_event_handle_t) -> T,
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs
index 9e7cbff..1bf10fd 100644
--- a/zluda/src/cuda.rs
+++ b/zluda/src/cuda.rs
@@ -2186,7 +2186,7 @@ pub extern "system" fn cuGetErrorString(
error: CUresult,
pStr: *mut *const ::std::os::raw::c_char,
) -> CUresult {
- r#impl::get_error_string(error, pStr).encuda()
+ r#impl::get_error_string(error, pStr).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2209,7 +2209,10 @@ pub extern "system" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult {
+pub extern "system" fn cuDeviceGet(
+ device: *mut CUdevice,
+ ordinal: ::std::os::raw::c_int,
+) -> CUresult {
r#impl::device::get(device.decuda(), ordinal).encuda()
}
@@ -2374,7 +2377,7 @@ pub extern "system" fn cuCtxGetFlags(flags: *mut ::std::os::raw::c_uint) -> CUre
#[cfg_attr(not(test), no_mangle)]
pub extern "system" fn cuCtxSynchronize() -> CUresult {
- r#impl::context::synchronize()
+ r#impl::context::synchronize().encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2429,7 +2432,10 @@ pub extern "system" fn cuCtxResetPersistingL2Cache() -> CUresult {
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult {
+pub extern "system" fn cuCtxAttach(
+ pctx: *mut CUcontext,
+ flags: ::std::os::raw::c_uint,
+) -> CUresult {
r#impl::context::attach(pctx.decuda(), flags).encuda()
}
@@ -2667,7 +2673,10 @@ pub extern "system" fn cuDeviceGetPCIBusId(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuIpcGetEventHandle(pHandle: *mut CUipcEventHandle, event: CUevent) -> CUresult {
+pub extern "system" fn cuIpcGetEventHandle(
+ pHandle: *mut CUipcEventHandle,
+ event: CUevent,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -2680,7 +2689,10 @@ pub extern "system" fn cuIpcOpenEventHandle(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuIpcGetMemHandle(pHandle: *mut CUipcMemHandle, dptr: CUdeviceptr) -> CUresult {
+pub extern "system" fn cuIpcGetMemHandle(
+ pHandle: *mut CUipcMemHandle,
+ dptr: CUdeviceptr,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -2930,12 +2942,18 @@ pub extern "system" fn cuMemcpyAtoHAsync_v2(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuMemcpy2DAsync_v2(pCopy: *const CUDA_MEMCPY2D, hStream: CUstream) -> CUresult {
+pub extern "system" fn cuMemcpy2DAsync_v2(
+ pCopy: *const CUDA_MEMCPY2D,
+ hStream: CUstream,
+) -> CUresult {
r#impl::unimplemented()
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuMemcpy3DAsync_v2(pCopy: *const CUDA_MEMCPY3D, hStream: CUstream) -> CUresult {
+pub extern "system" fn cuMemcpy3DAsync_v2(
+ pCopy: *const CUDA_MEMCPY3D,
+ hStream: CUstream,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3406,7 +3424,9 @@ pub extern "system" fn cuStreamBeginCapture_v2(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuThreadExchangeStreamCaptureMode(mode: *mut CUstreamCaptureMode) -> CUresult {
+pub extern "system" fn cuThreadExchangeStreamCaptureMode(
+ mode: *mut CUstreamCaptureMode,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3449,7 +3469,7 @@ pub extern "system" fn cuStreamQuery(hStream: CUstream) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "system" fn cuStreamSynchronize(hStream: CUstream) -> CUresult {
- CUresult::CUDA_SUCCESS
+ r#impl::stream::synchronize(hStream.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -3481,7 +3501,10 @@ pub extern "system" fn cuStreamSetAttribute(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuEventCreate(phEvent: *mut CUevent, Flags: ::std::os::raw::c_uint) -> CUresult {
+pub extern "system" fn cuEventCreate(
+ phEvent: *mut CUevent,
+ Flags: ::std::os::raw::c_uint,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3652,7 +3675,10 @@ pub extern "system" fn cuFuncSetCacheConfig(hfunc: CUfunction, config: CUfunc_ca
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuFuncSetSharedMemConfig(hfunc: CUfunction, config: CUsharedconfig) -> CUresult {
+pub extern "system" fn cuFuncSetSharedMemConfig(
+ hfunc: CUfunction,
+ config: CUsharedconfig,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3770,7 +3796,10 @@ pub extern "system" fn cuFuncSetSharedSize(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuParamSetSize(hfunc: CUfunction, numbytes: ::std::os::raw::c_uint) -> CUresult {
+pub extern "system" fn cuParamSetSize(
+ hfunc: CUfunction,
+ numbytes: ::std::os::raw::c_uint,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3836,7 +3865,10 @@ pub extern "system" fn cuParamSetTexRef(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuGraphCreate(phGraph: *mut CUgraph, flags: ::std::os::raw::c_uint) -> CUresult {
+pub extern "system" fn cuGraphCreate(
+ phGraph: *mut CUgraph,
+ flags: ::std::os::raw::c_uint,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3980,7 +4012,10 @@ pub extern "system" fn cuGraphAddEmptyNode(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuGraphClone(phGraphClone: *mut CUgraph, originalGraph: CUgraph) -> CUresult {
+pub extern "system" fn cuGraphClone(
+ phGraphClone: *mut CUgraph,
+ originalGraph: CUgraph,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -3994,7 +4029,10 @@ pub extern "system" fn cuGraphNodeFindInClone(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuGraphNodeGetType(hNode: CUgraphNode, type_: *mut CUgraphNodeType) -> CUresult {
+pub extern "system" fn cuGraphNodeGetType(
+ hNode: CUgraphNode,
+ type_: *mut CUgraphNodeType,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -4144,7 +4182,10 @@ pub extern "system" fn cuGraphExecUpdate(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuGraphKernelNodeCopyAttributes(dst: CUgraphNode, src: CUgraphNode) -> CUresult {
+pub extern "system" fn cuGraphKernelNodeCopyAttributes(
+ dst: CUgraphNode,
+ src: CUgraphNode,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -4284,7 +4325,10 @@ pub extern "system" fn cuTexRefSetFilterMode(hTexRef: CUtexref, fm: CUfilter_mod
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefSetMipmapFilterMode(hTexRef: CUtexref, fm: CUfilter_mode) -> CUresult {
+pub extern "system" fn cuTexRefSetMipmapFilterMode(
+ hTexRef: CUtexref,
+ fm: CUfilter_mode,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -4311,17 +4355,26 @@ pub extern "system" fn cuTexRefSetMaxAnisotropy(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefSetBorderColor(hTexRef: CUtexref, pBorderColor: *mut f32) -> CUresult {
+pub extern "system" fn cuTexRefSetBorderColor(
+ hTexRef: CUtexref,
+ pBorderColor: *mut f32,
+) -> CUresult {
r#impl::unimplemented()
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefSetFlags(hTexRef: CUtexref, Flags: ::std::os::raw::c_uint) -> CUresult {
+pub extern "system" fn cuTexRefSetFlags(
+ hTexRef: CUtexref,
+ Flags: ::std::os::raw::c_uint,
+) -> CUresult {
r#impl::unimplemented()
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefGetAddress_v2(pdptr: *mut CUdeviceptr, hTexRef: CUtexref) -> CUresult {
+pub extern "system" fn cuTexRefGetAddress_v2(
+ pdptr: *mut CUdeviceptr,
+ hTexRef: CUtexref,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -4348,7 +4401,10 @@ pub extern "system" fn cuTexRefGetAddressMode(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefGetFilterMode(pfm: *mut CUfilter_mode, hTexRef: CUtexref) -> CUresult {
+pub extern "system" fn cuTexRefGetFilterMode(
+ pfm: *mut CUfilter_mode,
+ hTexRef: CUtexref,
+) -> CUresult {
r#impl::unimplemented()
}
@@ -4392,7 +4448,10 @@ pub extern "system" fn cuTexRefGetMaxAnisotropy(
}
#[cfg_attr(not(test), no_mangle)]
-pub extern "system" fn cuTexRefGetBorderColor(pBorderColor: *mut f32, hTexRef: CUtexref) -> CUresult {
+pub extern "system" fn cuTexRefGetBorderColor(
+ pBorderColor: *mut f32,
+ hTexRef: CUtexref,
+) -> CUresult {
r#impl::unimplemented()
}
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index 5ef427e..9ea0874 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -257,9 +257,14 @@ pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
})?
}
-pub(crate) fn synchronize() -> CUresult {
- // TODO: change the implementation once we do async stream operations
- CUresult::CUDA_SUCCESS
+pub(crate) fn synchronize() -> Result<(), CUresult> {
+ GlobalState::lock_current_context(|ctx| {
+ ctx.default_stream.synchronize()?;
+ for stream in ctx.streams.iter().copied() {
+ unsafe { &mut *stream }.synchronize()?;
+ }
+ Ok(())
+ })?
}
#[cfg(test)]
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 63bf39f..0594252 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -1,4 +1,4 @@
-use super::{context, CUresult, GlobalState};
+use super::{context, transmute_lifetime, transmute_lifetime_mut, CUresult, GlobalState};
use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st};
use std::{
@@ -21,6 +21,7 @@ pub struct Device {
pub default_queue: l0::CommandQueue<'static>,
pub l0_context: l0::Context,
pub primary_context: context::Context,
+ pub event_pool: DynamicEventPool,
properties: Option<Box<l0::sys::ze_device_properties_t>>,
image_properties: Option<Box<l0::sys::ze_device_image_properties_t>>,
memory_properties: Option<Vec<l0::sys::ze_device_memory_properties_t>>,
@@ -42,12 +43,14 @@ impl Device {
true,
ptr::null_mut(),
)?);
+ let event_pool = DynamicEventPool::new(l0_dev, transmute_lifetime(&ctx))?;
Ok(Self {
index: Index(idx as c_int),
base: l0_dev,
default_queue: queue,
l0_context: ctx,
primary_context: primary_context,
+ event_pool,
properties: None,
image_properties: None,
memory_properties: None,
@@ -395,8 +398,103 @@ pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult {
CUresult::CUDA_SUCCESS
}
+pub struct DynamicEventPool {
+ count: usize,
+ events: Vec<DynamicEventPoolEntry>,
+}
+
+impl DynamicEventPool {
+ fn new(dev: l0::Device, ctx: &'static l0::Context) -> l0::Result<Self> {
+ Ok(DynamicEventPool {
+ count: 0,
+ events: vec![DynamicEventPoolEntry::new(dev, ctx)?],
+ })
+ }
+
+ pub fn get(
+ &'static mut self,
+ dev: l0::Device,
+ ctx: &'static l0::Context,
+ ) -> l0::Result<(l0::Event<'static>, u64)> {
+ self.count += 1;
+ let events = unsafe { transmute_lifetime_mut(&mut self.events) };
+ let (global_idx, (ev, local_idx)) = {
+ for (idx, entry) in self.events.iter_mut().enumerate() {
+ if let Some((ev, local_idx)) = entry.get()? {
+ let marker = (idx << 32) as u64 | local_idx as u64;
+ return Ok((ev, marker));
+ }
+ }
+ events.push(DynamicEventPoolEntry::new(dev, ctx)?);
+ let global_idx = (events.len() - 1) as u64;
+ (global_idx, events.last_mut().unwrap().get()?.unwrap())
+ };
+ let marker = (global_idx << 32) | local_idx as u64;
+ Ok((ev, marker))
+ }
+
+ pub fn mark_as_free(&mut self, marker: u64) {
+ let global_idx = (marker >> 32) as u32;
+ self.events[global_idx as usize].mark_as_free(marker as u32);
+ self.count -= 1;
+ // TODO: clean up empty entries
+ }
+}
+
+const DYNAMIC_EVENT_POOL_ENTRY_SIZE: usize = 448;
+const DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE: usize =
+ DYNAMIC_EVENT_POOL_ENTRY_SIZE / (mem::size_of::<u64>() * 8);
+#[repr(C)]
+#[repr(align(64))]
+struct DynamicEventPoolEntry {
+ event_pool: l0::EventPool<'static>,
+ bit_map: [u64; DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE],
+}
+
+impl DynamicEventPoolEntry {
+ fn new(dev: l0::Device, ctx: &'static l0::Context) -> l0::Result<Self> {
+ Ok(DynamicEventPoolEntry {
+ event_pool: l0::EventPool::new(
+ ctx,
+ DYNAMIC_EVENT_POOL_ENTRY_SIZE as u32,
+ Some(&[dev]),
+ )?,
+ bit_map: [0; DYNAMIC_EVENT_POOL_ENTRY_BITMAP_SIZE],
+ })
+ }
+
+ fn get(&'static mut self) -> l0::Result<Option<(l0::Event<'static>, u32)>> {
+ for (idx, value) in self.bit_map.iter_mut().enumerate() {
+ let shift = first_index_of_zero_u64(*value);
+ if shift == 64 {
+ continue;
+ }
+ *value = *value | (1u64 << shift);
+ let entry_index = (idx as u32 * 64u32) + shift;
+ let event = l0::Event::new(&self.event_pool, entry_index)?;
+ return Ok(Some((event, entry_index)));
+ }
+ Ok(None)
+ }
+
+ fn mark_as_free(&mut self, idx: u32) {
+ let value = &mut self.bit_map[idx as usize / 64];
+ let shift = idx % 64;
+ *value = *value & !(1 << shift);
+ }
+}
+
+fn first_index_of_zero_u64(x: u64) -> u32 {
+ let x = !x;
+ (x & x.wrapping_neg()).trailing_zeros()
+}
+
#[cfg(test)]
mod test {
+ use std::mem;
+
+ use super::DynamicEventPoolEntry;
+
use super::super::test::CudaDriverFns;
use super::super::CUresult;
@@ -413,4 +511,10 @@ mod test {
assert_eq!(flags, 0);
assert_eq!(active, 0);
}
+
+ #[test]
+ pub fn dynamic_event_pool_page_is_64b() {
+ assert_eq!(mem::size_of::<DynamicEventPoolEntry>(), 64);
+ assert_eq!(mem::align_of::<DynamicEventPoolEntry>(), 64);
+ }
}
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index 238d68e..81b4f31 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -11,13 +11,10 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
}
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
- GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let cmd_list = stream.command_list()?;
- unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, None, &mut [])? };
- cmd_list.close()?;
- stream.queue.execute_and_synchronize(cmd_list)?;
- Ok::<_, CUresult>(())
- })?
+ GlobalState::lock_enqueue(stream::CU_STREAM_LEGACY, |cmd_list, signal, wait| {
+ unsafe { cmd_list.append_memory_copy_raw(dst, src, bytesize, Some(signal), wait)? };
+ Ok::<_, l0::sys::ze_result_t>(())
+ })
}
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 55c047f..48b5fb5 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -273,6 +273,32 @@ impl GlobalState {
}
}
+ fn lock_enqueue(
+ stream: *mut stream::Stream,
+ f: impl FnOnce(
+ &mut l0::CommandList,
+ &l0::Event<'static>,
+ &[&l0::Event<'static>],
+ ) -> l0::Result<()>,
+ ) -> Result<(), CUresult> {
+ Self::lock_stream(stream, |stream_data| {
+ let l0_dev = unsafe { (*(*stream_data.context).device).base };
+ let l0_ctx = unsafe { &mut (*(*stream_data.context).device).l0_context };
+ let event_pool = unsafe { &mut (*(*stream_data.context).device).event_pool };
+ let mut cmd_list = unsafe { mem::transmute(stream_data.command_list()?) };
+ stream_data
+ .process_finished_events(&mut |(_, marker)| event_pool.mark_as_free(marker))?;
+ let prev_event = stream_data.get_last_event();
+ let prev_event_array = prev_event.map(|e| [e]);
+ let empty = [];
+ let prev_event_slice = prev_event_array.as_ref().map_or(&empty[..], |arr| &arr[..]);
+ let (new_event, new_marker) = event_pool.get(l0_dev, l0_ctx)?;
+ f(&mut cmd_list, &new_event, prev_event_slice)?;
+ stream_data.push_event((new_event, new_marker));
+ Ok(())
+ })?
+ }
+
fn lock_function<T>(
func: *mut function::Function,
f: impl FnOnce(&mut function::FunctionData) -> T,
@@ -421,6 +447,10 @@ pub(crate) fn get_error_string(error: CUresult, str: *mut *const i8) -> CUresult
}
}
+unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
+ mem::transmute(t)
+}
+
unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t)
}
diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs
index 0fafe92..11f1869 100644
--- a/zluda/src/impl/stream.rs
+++ b/zluda/src/impl/stream.rs
@@ -2,7 +2,7 @@ use super::{
context::{Context, ContextData},
CUresult, GlobalState,
};
-use std::{mem, ptr};
+use std::{collections::VecDeque, mem, ptr};
use super::{HasLivenessCookie, LiveCheck};
@@ -34,21 +34,27 @@ impl HasLivenessCookie for StreamData {
pub struct StreamData {
pub context: *mut ContextData,
pub queue: l0::CommandQueue<'static>,
+ pub prev_events: VecDeque<(l0::Event<'static>, u64)>,
}
impl StreamData {
- pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
+ pub fn new_unitialized(
+ ctx: &'static l0::Context,
+ device: l0::Device,
+ ) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
- queue: l0::CommandQueue::new(ctx, dev)?,
+ queue: l0::CommandQueue::new(ctx, device)?,
+ prev_events: VecDeque::new(),
})
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
- let l0_dev = unsafe { &*ctx.device }.base;
+ let device = unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
- queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
+ queue: l0::CommandQueue::new(l0_ctx, device)?,
+ prev_events: VecDeque::new(),
})
}
@@ -57,6 +63,39 @@ impl StreamData {
let dev = unsafe { &mut *ctx.device };
l0::CommandList::new(&mut dev.l0_context, dev.base)
}
+
+ pub fn process_finished_events(
+ &mut self,
+ f: &mut impl FnMut((l0::Event<'static>, u64)),
+ ) -> l0::Result<()> {
+ loop {
+ match self.prev_events.get(0) {
+ None => return Ok(()),
+ Some((ev, _)) => {
+ if ev.is_ready()? {
+ f(self.prev_events.pop_front().unwrap());
+ } else {
+ return Ok(());
+ }
+ }
+ }
+ }
+ }
+
+ pub fn get_last_event(&self) -> Option<&l0::Event<'static>> {
+ self.prev_events.iter().next_back().map(|(ev, _)| ev)
+ }
+
+ pub fn push_event(&mut self, ev: (l0::Event<'static>, u64)) {
+ self.prev_events.push_back(ev);
+ }
+
+ pub fn synchronize(&mut self) -> l0::Result<()> {
+ self.queue.synchronize(u64::MAX)?;
+ let event_pool = unsafe { &mut (*(*self.context).device).event_pool };
+ self.process_finished_events(&mut |(_, marker)| event_pool.mark_as_free(marker))?;
+ Ok(())
+ }
}
impl Drop for StreamData {
@@ -102,6 +141,13 @@ pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> {
GlobalState::lock(|_| Stream::destroy_impl(pstream))?
}
+pub(crate) fn synchronize(pstream: *mut Stream) -> Result<(), CUresult> {
+ if pstream == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock_stream(pstream, |stream_data| Ok(stream_data.synchronize()?))?
+}
+
#[cfg(test)]
mod test {
use crate::cuda::CUstream;