aboutsummaryrefslogtreecommitdiffhomepage
path: root/level_zero
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2021-05-28 00:14:45 +0200
committer Andrzej Janik <[email protected]> 2021-05-28 00:14:45 +0200
commit 2fc7af0434256a353af130708a2dafb97be99d24 (patch)
tree ac9321f4b909824c37fbd21b6de6c76bff985d40 /level_zero
parent e40785aa7491de16c65de7aa599105102ffa7355 (diff)
download ZLUDA-2fc7af0434256a353af130708a2dafb97be99d24.tar.gz
ZLUDA-2fc7af0434256a353af130708a2dafb97be99d24.zip
Fix level zero bindings
Diffstat (limited to 'level_zero')
-rw-r--r-- level_zero/src/ze.rs 284
1 files changed, 177 insertions, 107 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index 88adfe6..f4cd0ae 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -737,129 +737,122 @@ impl<'a> CommandList<'a> {
Ok(unsafe { Self::from_ffi(result) })
}
- pub fn append_memory_copy<'event, T: 'a, Dst: Into<Slice<'a, T>>, Src: Into<Slice<'a, T>>>(
- &'a self,
+ pub unsafe fn append_memory_copy<
+ 'dep,
+ T: 'a + 'dep + Copy + Sized,
+ Dst: Into<Slice<'dep, T>>,
+ Src: Into<Slice<'dep, T>>,
+ >(
+ &self,
dst: Dst,
src: Src,
- signal: Option<&Event<'event>>,
- wait: &[Event<'event>],
- ) -> Result<()>
- where
- 'event: 'a,
- {
+ signal: Option<&Event<'dep>>,
+ wait: &[&'dep Event<'dep>],
+ ) -> Result<()> {
let dst = dst.into();
let src = src.into();
let elements = std::cmp::min(dst.len(), src.len());
let length = elements * mem::size_of::<T>();
- unsafe {
- self.append_memory_copy_unsafe(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait)
- }
+ self.append_memory_copy_raw(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait)
}
- pub unsafe fn append_memory_copy_unsafe(
+ pub unsafe fn append_memory_copy_raw(
&self,
dst: *mut c_void,
src: *const c_void,
length: usize,
signal: Option<&Event>,
- wait: &[Event],
+ wait: &[&Event],
) -> Result<()> {
- let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut());
- let (wait_len, wait_ptr) = Event::raw_slice(wait);
- check!(sys::zeCommandListAppendMemoryCopy(
- self.as_ffi(),
- dst,
- src,
- length,
- signal_event,
- wait_len,
- wait_ptr
- ));
- Ok(())
- }
-
- pub fn append_memory_fill<'event, T: 'a, Dst: Into<Slice<'a, T>>>(
+ let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
+ Event::with_raw_slice(wait, |wait_len, wait_ptr| {
+ check!(sys::zeCommandListAppendMemoryCopy(
+ self.as_ffi(),
+ dst,
+ src,
+ length,
+ signal_event,
+ wait_len,
+ wait_ptr
+ ));
+ Ok(())
+ })
+ }
+
+ pub unsafe fn append_memory_fill<'dep, T: Copy + Sized + 'dep, Dst: Into<Slice<'dep, T>>>(
&'a self,
dst: Dst,
- pattern: u8,
- signal: Option<&Event<'event>>,
- wait: &[Event<'event>],
- ) -> Result<()>
- where
- 'event: 'a,
- {
+ pattern: &T,
+ signal: Option<&Event<'dep>>,
+ wait: &[&'dep Event<'dep>],
+ ) -> Result<()> {
let dst = dst.into();
- let raw_pattern = &pattern as *const u8 as *const _;
- let signal_event = signal
- .map(|e| unsafe { e.as_ffi() })
- .unwrap_or(ptr::null_mut());
- let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
- let byte_len = dst.len() * mem::size_of::<T>();
- check!(sys::zeCommandListAppendMemoryFill(
- self.as_ffi(),
- dst.as_mut_ptr(),
- raw_pattern,
- mem::size_of::<u8>(),
- byte_len,
- signal_event,
- wait_len,
- wait_ptr
- ));
- Ok(())
- }
-
- pub unsafe fn append_memory_fill_unsafe<T: Copy + Sized>(
+ let raw_pattern = pattern as *const _ as *const _;
+ let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
+ Event::with_raw_slice(wait, |wait_len, wait_ptr| {
+ check!(sys::zeCommandListAppendMemoryFill(
+ self.as_ffi(),
+ dst.as_mut_ptr(),
+ raw_pattern,
+ mem::size_of::<T>(),
+ dst.len() * mem::size_of::<T>(),
+ signal_event,
+ wait_len,
+ wait_ptr
+ ));
+ Ok(())
+ })
+ }
+
+ pub unsafe fn append_memory_fill_raw(
&self,
dst: *mut c_void,
- pattern: &T,
- byte_size: usize,
+ pattern: *mut c_void,
+ pattern_size: usize,
+ size: usize,
signal: Option<&Event>,
- wait: &[Event],
+ wait: &[&Event],
) -> Result<()> {
- let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut());
- let (wait_len, wait_ptr) = Event::raw_slice(wait);
- check!(sys::zeCommandListAppendMemoryFill(
- self.as_ffi(),
- dst,
- pattern as *const T as *const _,
- mem::size_of::<T>(),
- byte_size,
- signal_event,
- wait_len,
- wait_ptr
- ));
- Ok(())
+ let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
+ Event::with_raw_slice(wait, |wait_len, wait_ptr| {
+ check!(sys::zeCommandListAppendMemoryFill(
+ self.as_ffi(),
+ dst,
+ pattern,
+ pattern_size,
+ size,
+ signal_event,
+ wait_len,
+ wait_ptr
+ ));
+ Ok(())
+ })
}
- pub fn append_launch_kernel<'event, 'kernel>(
- &'a self,
- kernel: &'kernel Kernel,
+ pub unsafe fn append_launch_kernel(
+ &self,
+ kernel: &Kernel,
group_count: &[u32; 3],
- signal: Option<&Event<'event>>,
- wait: &[Event<'event>],
- ) -> Result<()>
- where
- 'event: 'a,
- 'kernel: 'a,
- {
+ signal: Option<&Event>,
+ wait: &[&Event],
+ ) -> Result<()> {
let gr_count = sys::ze_group_count_t {
groupCountX: group_count[0],
groupCountY: group_count[1],
groupCountZ: group_count[2],
};
- let signal_event = signal
- .map(|e| unsafe { e.as_ffi() })
- .unwrap_or(ptr::null_mut());
- let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
- check!(sys::zeCommandListAppendLaunchKernel(
- self.as_ffi(),
- kernel.as_ffi(),
- &gr_count,
- signal_event,
- wait_len,
- wait_ptr,
- ));
- Ok(())
+ let signal_event = signal.map_or(ptr::null_mut(), |e| e.as_ffi());
+ Event::with_raw_slice(wait, |wait_len, wait_ptr| {
+ check!(sys::zeCommandListAppendLaunchKernel(
+ self.as_ffi(),
+ kernel.as_ffi(),
+ &gr_count,
+ signal_event,
+ wait_len,
+ wait_ptr,
+ ));
+ Ok(())
+ })
}
pub fn close(&self) -> Result<()> {
@@ -875,17 +868,86 @@ impl<'a> Drop for CommandList<'a> {
}
}
+pub struct CommandListBuilder<'a>(CommandList<'a>);
+
+unsafe impl<'a> Send for CommandListBuilder<'a> {}
+
+impl<'a> CommandListBuilder<'a> {
+ pub fn new(ctx: &'a Context, dev: Device) -> Result<Self> {
+ Ok(CommandListBuilder(CommandList::new(ctx, dev)?))
+ }
+
+ pub fn append_memory_copy<
+ 'dep,
+ 'result,
+ T: 'dep + Copy + Sized,
+ Dst: Into<Slice<'dep, T>>,
+ Src: Into<Slice<'dep, T>>,
+ >(
+ self,
+ dst: Dst,
+ src: Src,
+ signal: Option<&'dep Event<'dep>>,
+ wait: &[&'dep Event<'dep>],
+ ) -> Result<CommandListBuilder<'result>>
+ where
+ 'a: 'result,
+ 'dep: 'result,
+ {
+ unsafe { self.0.append_memory_copy(dst, src, signal, wait) }?;
+ Ok(self)
+ }
+
+ pub fn append_memory_fill<'dep, 'result, T: 'dep + Copy + Sized, Dst: Into<Slice<'dep, T>>>(
+ self,
+ dst: Dst,
+ pattern: &T,
+ signal: Option<&Event<'dep>>,
+ wait: &[&'dep Event<'dep>],
+ ) -> Result<CommandListBuilder<'result>>
+ where
+ 'a: 'result,
+ 'dep: 'result,
+ {
+ unsafe { self.0.append_memory_fill(dst, pattern, signal, wait) }?;
+ Ok(self)
+ }
+
+ pub fn append_launch_kernel<'dep, 'result>(
+ self,
+ kernel: &'dep Kernel,
+ group_count: &[u32; 3],
+ signal: Option<&Event<'dep>>,
+ wait: &[&'dep Event<'dep>],
+ ) -> Result<CommandListBuilder<'result>>
+ where
+ 'a: 'result,
+ 'dep: 'result,
+ {
+ unsafe {
+ self.0
+ .append_launch_kernel(kernel, group_count, signal, wait)
+ }?;
+ Ok(self)
+ }
+
+ pub fn execute(self, q: &'a CommandQueue<'a>) -> Result<FenceGuard<'a>> {
+ self.0.close()?;
+ q.execute_and_synchronize(self.0)
+ }
+}
+
#[derive(Copy, Clone)]
-pub struct Slice<'a, T> {
+pub struct Slice<'a, T: Copy + Sized> {
ptr: *mut c_void,
len: usize,
marker: PhantomData<&'a T>,
}
-unsafe impl<'a, T> Send for Slice<'a, T> {}
-unsafe impl<'a, T> Sync for Slice<'a, T> {}
+unsafe impl<'a, T: Copy + Sized> Send for Slice<'a, T> {}
+unsafe impl<'a, T: Copy + Sized> Sync for Slice<'a, T> {}
-impl<'a, T> Slice<'a, T> {
+impl<'a, T: Copy + Sized> Slice<'a, T> {
pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self {
Self {
ptr,
@@ -907,7 +969,7 @@ impl<'a, T> Slice<'a, T> {
}
}
-impl<'a, T> From<&'a [T]> for Slice<'a, T> {
+impl<'a, T: Copy + Sized> From<&'a [T]> for Slice<'a, T> {
fn from(s: &'a [T]) -> Self {
Slice {
ptr: s.as_ptr() as *mut _,
@@ -917,7 +979,7 @@ impl<'a, T> From<&'a [T]> for Slice<'a, T> {
}
}
-impl<'a, T: Copy> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> {
+impl<'a, T: Copy + Sized> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> {
fn from(b: &'a DeviceBuffer<'a, T>) -> Self {
Slice {
ptr: b.ptr,
@@ -996,13 +1058,21 @@ impl<'a> Event<'a> {
Ok(unsafe { Self::from_ffi(result) })
}
- unsafe fn raw_slice(e: &[Event]) -> (u32, *mut sys::ze_event_handle_t) {
- let ptr = if e.len() == 0 {
- ptr::null()
- } else {
- e.as_ptr()
+ unsafe fn with_raw_slice<'x, T>(
+ events: &[&Event<'x>],
+ f: impl FnOnce(u32, *mut sys::ze_event_handle_t) -> T,
+ ) -> T {
+ let (ptr, ev_vec) = match events {
+ [] => (ptr::null_mut(), None),
+ [e] => (&e.0 as *const _ as *mut _, None),
+ _ => {
+ let mut ev_vec = events.iter().map(|e| e.as_ffi()).collect::<Vec<_>>();
+ (ev_vec.as_mut_ptr(), Some(ev_vec))
+ }
};
- (e.len() as u32, ptr as *mut sys::ze_event_handle_t)
+ let result = f(events.len() as u32, ptr);
+ drop(ev_vec);
+ result
}
}
@@ -1042,7 +1112,7 @@ impl<'a> Kernel<'a> {
Ok(())
}
- pub fn set_arg_buffer<T: 'a, Buff: Into<Slice<'a, T>>>(
+ pub fn set_arg_buffer<T: 'a + Copy + Sized, Buff: Into<Slice<'a, T>>>(
&self,
index: u32,
buff: Buff,