summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--level_zero/src/ze.rs311
-rw-r--r--ptx/Cargo.toml1
-rw-r--r--ptx/src/lib.rs2
-rw-r--r--ptx/src/test/ops/mod.rs289
4 files changed, 359 insertions, 244 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index 5eed623..8e1e4e5 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -1,6 +1,7 @@
use crate::sys;
+use std::num::NonZeroUsize;
use std::{
- ffi::c_void,
+ ffi::{c_void, CStr},
fmt::{Debug, Display},
marker::PhantomData,
mem, ptr,
@@ -283,7 +284,10 @@ impl CommandQueue {
let mut raw_cmd = cmd.0;
mem::forget(cmd);
check!(sys::zeCommandQueueExecuteCommandLists(
- self.0, 1, &mut raw_cmd, result.0
+ self.0,
+ 1,
+ &mut raw_cmd,
+ result.0
));
Ok(result)
}
@@ -360,6 +364,7 @@ impl SafeRepr for f64 {}
pub struct DeviceBuffer<T: SafeRepr> {
ptr: *mut c_void,
driver: sys::ze_driver_handle_t,
+ len: usize,
marker: PhantomData<T>,
}
@@ -367,11 +372,12 @@ impl<T: SafeRepr> DeviceBuffer<T> {
pub unsafe fn as_ffi(&self) -> *mut c_void {
self.ptr
}
- pub unsafe fn from_ffi(driver: sys::ze_driver_handle_t, ptr: *mut c_void) -> Self {
+ pub unsafe fn from_ffi(driver: sys::ze_driver_handle_t, ptr: *mut c_void, len: usize) -> Self {
let marker = PhantomData::<T>;
Self {
ptr,
driver,
+ len,
marker,
}
}
@@ -392,7 +398,11 @@ impl<T: SafeRepr> DeviceBuffer<T> {
dev.0,
&mut result
));
- Ok(unsafe { Self::from_ffi(drv.0, result) })
+ Ok(unsafe { Self::from_ffi(drv.0, result, len) })
+ }
+
+ pub fn len(&self) -> usize {
+ self.len
}
}
@@ -403,28 +413,98 @@ impl<T: SafeRepr> Drop for DeviceBuffer<T> {
}
}
-pub struct CommandList(sys::ze_command_list_handle_t);
+pub struct CommandList<'a>(sys::ze_command_list_handle_t, PhantomData<&'a ()>);
-impl CommandList {
+impl<'a> CommandList<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_command_list_handle_t {
self.0
}
pub unsafe fn from_ffi(x: sys::ze_command_list_handle_t) -> Self {
- Self(x)
+ Self(x, PhantomData)
}
pub fn new(dev: &Device) -> Result<Self> {
- let desc = sys::_ze_command_list_desc_t {
+ let desc = sys::ze_command_list_desc_t {
version: sys::ze_command_list_desc_version_t::ZE_COMMAND_LIST_DESC_VERSION_CURRENT,
flags: sys::ze_command_list_flag_t::ZE_COMMAND_LIST_FLAG_NONE,
};
let mut result: sys::ze_command_list_handle_t = ptr::null_mut();
check!(sys::zeCommandListCreate(dev.0, &desc, &mut result));
- Ok(Self(result))
+ Ok(Self(result, PhantomData))
+ }
+
+ pub fn append_memory_copy<
+ T: 'a,
+ Dst: Into<BufferPtrMut<'a, T>>,
+ Src: Into<BufferPtr<'a, T>>,
+ >(
+ &mut self,
+ dst: Dst,
+ src: Src,
+ length: Option<usize>,
+ signal: Option<&Event<'a>>,
+ ) -> Result<()> {
+ let dst = dst.into();
+ let src = src.into();
+ let elements = length.unwrap_or(std::cmp::max(dst.len(), src.len()));
+ let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ check!(sys::zeCommandListAppendMemoryCopy(
+ self.0,
+ dst.get(),
+ src.get(),
+ elements * std::mem::size_of::<T>(),
+ event,
+ ));
+ Ok(())
+ }
+
+ pub fn append_memory_fill<T>(
+ &mut self,
+ dst: BufferPtrMut<'a, T>,
+ pattern: T,
+ signal: Option<&Event<'a>>,
+ ) -> Result<()> {
+ let raw_pattern = &pattern as *const T as *const _;
+ let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let byte_len = dst.len() * mem::size_of::<T>();
+ check!(sys::zeCommandListAppendMemoryFill(
+ self.0,
+ dst.get(),
+ raw_pattern,
+ mem::size_of::<T>(),
+ byte_len,
+ event,
+ ));
+ Ok(())
+ }
+
+ pub fn append_launch_kernel(
+ &mut self,
+ kernel: &'a Kernel,
+ group_count: &[u32; 3],
+ signal: Option<&Event<'a>>,
+ wait: &[&Event<'a>],
+ ) -> Result<()> {
+ let gr_count = sys::ze_group_count_t {
+ groupCountX: group_count[0],
+ groupCountY: group_count[1],
+ groupCountZ: group_count[2],
+ };
+ let event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let mut wait_ptrs = wait.iter().map(|e| e.0).collect::<Vec<_>>();
+ check!(sys::zeCommandListAppendLaunchKernel(
+ self.0,
+ kernel.0,
+ &gr_count,
+ event,
+ wait.len() as u32,
+ wait_ptrs.as_mut_ptr(),
+ ));
+ Ok(())
}
}
-impl Drop for CommandList {
+impl<'a> Drop for CommandList<'a> {
#[allow(unused_must_use)]
fn drop(&mut self) {
unsafe { sys::zeCommandListDestroy(self.0) };
@@ -457,3 +537,214 @@ impl<'a> Drop for FenceGuard<'a> {
unsafe { sys::zeCommandListDestroy(self.1) };
}
}
+
+#[derive(Copy, Clone)]
+pub struct BufferPtr<'a, T> {
+ ptr: *const c_void,
+ marker: PhantomData<&'a T>,
+ elems: usize,
+}
+
+impl<'a, T> BufferPtr<'a, T> {
+ pub unsafe fn get(self) -> *const c_void {
+ return self.ptr;
+ }
+
+ pub fn len(&self) -> usize {
+ self.elems
+ }
+}
+
+impl<'a, T> From<&'a [T]> for BufferPtr<'a, T> {
+ fn from(s: &'a [T]) -> Self {
+ BufferPtr {
+ ptr: s.as_ptr() as *const _,
+ marker: PhantomData,
+ elems: s.len(),
+ }
+ }
+}
+
+impl<'a, T: SafeRepr> From<&'a DeviceBuffer<T>> for BufferPtr<'a, T> {
+ fn from(b: &'a DeviceBuffer<T>) -> Self {
+ BufferPtr {
+ ptr: b.ptr as *const _,
+ marker: PhantomData,
+ elems: b.len(),
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct BufferPtrMut<'a, T> {
+ ptr: *mut c_void,
+ marker: PhantomData<&'a mut T>,
+ elems: usize,
+}
+
+impl<'a, T> BufferPtrMut<'a, T> {
+ pub unsafe fn get(self) -> *mut c_void {
+ return self.ptr;
+ }
+
+ pub fn len(&self) -> usize {
+ self.elems
+ }
+}
+
+impl<'a, T> From<&'a mut [T]> for BufferPtrMut<'a, T> {
+ fn from(s: &'a mut [T]) -> Self {
+ BufferPtrMut {
+ ptr: s.as_mut_ptr() as *mut _,
+ marker: PhantomData,
+ elems: s.len(),
+ }
+ }
+}
+
+impl<'a, T: SafeRepr> From<&'a mut DeviceBuffer<T>> for BufferPtrMut<'a, T> {
+ fn from(b: &'a mut DeviceBuffer<T>) -> Self {
+ BufferPtrMut {
+ ptr: b.ptr as *mut _,
+ marker: PhantomData,
+ elems: b.len(),
+ }
+ }
+}
+
+impl<'a, T: SafeRepr> From<BufferPtrMut<'a, T>> for BufferPtr<'a, T> {
+ fn from(b: BufferPtrMut<'a, T>) -> Self {
+ BufferPtr {
+ ptr: b.ptr,
+ marker: PhantomData,
+ elems: b.len(),
+ }
+ }
+}
+pub struct EventPool<'a>(sys::ze_event_pool_handle_t, PhantomData<&'a ()>);
+
+impl<'a> EventPool<'a> {
+ pub unsafe fn as_ffi(&self) -> sys::ze_event_pool_handle_t {
+ self.0
+ }
+ pub unsafe fn from_ffi(x: sys::ze_event_pool_handle_t) -> Self {
+ Self(x, PhantomData)
+ }
+ pub fn new(driver: &Driver, count: u32, dev: Option<&[&'a Device]>) -> Result<Self> {
+ let desc = sys::ze_event_pool_desc_t {
+ version: sys::ze_event_pool_desc_version_t::ZE_EVENT_POOL_DESC_VERSION_CURRENT,
+ flags: sys::ze_event_pool_flag_t::ZE_EVENT_POOL_FLAG_DEFAULT,
+ count: count,
+ };
+ let mut dev = dev.map(|d| d.iter().map(|d| d.0).collect::<Vec<_>>());
+ let dev_len = dev.as_ref().map_or(0, |d| d.len() as u32);
+ let dev_ptr = dev.as_mut().map_or(ptr::null_mut(), |d| d.as_mut_ptr());
+ let mut result = ptr::null_mut();
+ check!(sys::zeEventPoolCreate(
+ driver.0,
+ &desc,
+ dev_len,
+ dev_ptr,
+ &mut result
+ ));
+ Ok(Self(result, PhantomData))
+ }
+}
+
+impl<'a> Drop for EventPool<'a> {
+ #[allow(unused_must_use)]
+ fn drop(&mut self) {
+ unsafe { sys::zeEventPoolDestroy(self.0) };
+ }
+}
+
+pub struct Event<'a>(sys::ze_event_handle_t, PhantomData<&'a ()>);
+
+impl<'a> Event<'a> {
+ pub unsafe fn as_ffi(&self) -> sys::ze_event_handle_t {
+ self.0
+ }
+
+ pub unsafe fn from_ffi(x: sys::ze_event_handle_t) -> Self {
+ Self(x, PhantomData)
+ }
+
+ pub fn new(pool: &'a EventPool, index: u32) -> Result<Self> {
+ let desc = sys::ze_event_desc_t {
+ version: sys::ze_event_desc_version_t::ZE_EVENT_DESC_VERSION_CURRENT,
+ index: index,
+ signal: sys::ze_event_scope_flag_t::ZE_EVENT_SCOPE_FLAG_NONE,
+ wait: sys::ze_event_scope_flag_t::ZE_EVENT_SCOPE_FLAG_NONE,
+ };
+ let mut result = ptr::null_mut();
+ check!(sys::zeEventCreate(pool.0, &desc, &mut result));
+ Ok(Self(result, PhantomData))
+ }
+}
+
+impl<'a> Drop for Event<'a> {
+ #[allow(unused_must_use)]
+ fn drop(&mut self) {
+ unsafe { sys::zeEventDestroy(self.0) };
+ }
+}
+
+pub struct Kernel<'a>(sys::ze_kernel_handle_t, PhantomData<&'a ()>);
+
+impl<'a> Kernel<'a> {
+ pub unsafe fn as_ffi(&self) -> sys::ze_kernel_handle_t {
+ self.0
+ }
+
+ pub unsafe fn from_ffi(x: sys::ze_kernel_handle_t) -> Self {
+ Self(x, PhantomData)
+ }
+
+ pub fn new(module: &'a Module, name: &CStr) -> Result<Self> {
+ let desc = sys::ze_kernel_desc_t {
+ version: sys::ze_kernel_desc_version_t::ZE_KERNEL_DESC_VERSION_CURRENT,
+ flags: sys::ze_kernel_flag_t::ZE_KERNEL_FLAG_NONE,
+ pKernelName: name.as_ptr() as *const _,
+ };
+ let mut result = ptr::null_mut();
+ check!(sys::zeKernelCreate(module.0, &desc, &mut result));
+ Ok(Self(result, PhantomData))
+ }
+
+ pub fn set_arg_buffer<T: 'a, Buff: Into<BufferPtr<'a, T>>>(
+ &self,
+ index: u32,
+ buff: Buff,
+ ) -> Result<()> {
+ let ptr = unsafe { buff.into().get() };
+ check!(sys::zeKernelSetArgumentValue(
+ self.0,
+ index,
+ mem::size_of::<T>(),
+ &ptr as *const _ as *const _,
+ ));
+ Ok(())
+ }
+
+ pub fn set_arg_scalar<T: Copy>(&self, index: u32, value: &T) -> Result<()> {
+ check!(sys::zeKernelSetArgumentValue(
+ self.0,
+ index,
+ mem::size_of::<T>(),
+ value as *const T as *const _,
+ ));
+ Ok(())
+ }
+
+ pub fn set_group_size(&self, x: u32, y: u32, z: u32) -> Result<()> {
+ check!(sys::zeKernelSetGroupSize(self.0, x, y, z));
+ Ok(())
+ }
+}
+
+impl<'a> Drop for Kernel<'a> {
+ #[allow(unused_must_use)]
+ fn drop(&mut self) {
+ unsafe { sys::zeKernelDestroy(self.0) };
+ }
+}
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 8fbc82a..02fc23b 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -20,4 +20,5 @@ features = ["lexer"]
[dev-dependencies]
level_zero-sys = { path = "../level_zero-sys" }
+level_zero = { path = "../level_zero" }
ocl = { version = "0.19", features = ["opencl_version_1_1", "opencl_version_1_2", "opencl_version_2_1"] }
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 022fa97..5aaaccf 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -8,6 +8,8 @@ extern crate bit_vec;
extern crate ocl;
#[cfg(test)]
extern crate level_zero_sys as l0;
+#[cfg(test)]
+extern crate level_zero as ze;
extern crate rspirv;
extern crate spirv_headers as spirv;
diff --git a/ptx/src/test/ops/mod.rs b/ptx/src/test/ops/mod.rs
index 2537cf9..85938f6 100644
--- a/ptx/src/test/ops/mod.rs
+++ b/ptx/src/test/ops/mod.rs
@@ -2,7 +2,7 @@ use crate::ptx;
use crate::translate;
use ocl::{Buffer, Context, Device, Kernel, OclPrm, Platform, Program, Queue};
use std::error;
-use std::ffi::{c_void, CString};
+use std::ffi::{c_void, CStr, CString};
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::mem;
@@ -43,7 +43,7 @@ impl<T: Display + Debug> Debug for DisplayError<T> {
impl<T: Display + Debug> error::Error for DisplayError<T> {}
-fn test_ptx_assert<'a, T: OclPrm + From<u8>>(
+fn test_ptx_assert<'a, T: OclPrm + From<u8> + ze::SafeRepr>(
name: &str,
ptx_text: &'a str,
input: &[T],
@@ -58,242 +58,38 @@ fn test_ptx_assert<'a, T: OclPrm + From<u8>>(
Ok(())
}
-fn run_spirv<T: OclPrm + From<u8>>(
+fn run_spirv<T: OclPrm + From<u8> + ze::SafeRepr>(
name: &str,
spirv: &[u32],
input: &[T],
output: &mut [T],
-) -> ocl::Result<Vec<T>> {
- let (drv, device, queue) = unsafe { l0_init() };
- let (ocl_plat, ocl_dev) = get_ocl_platform_device();
- let ocl_ctx = Context::builder()
- .platform(ocl_plat)
- .devices(ocl_dev)
- .build()?;
- let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap();
- let src = CString::new(
- "
- __kernel void ld_st(ulong a, ulong b)
- {
- __global ulong* a_copy = (__global ulong*)a;
- __global ulong* b_copy = (__global ulong*)b;
- *b_copy = *a_copy;
- }",
- )?;
- let prog = Program::with_source(&ocl_ctx, &[src], None, &empty_cstr)?;
- let binaries_wrapped = prog.info(ocl::core::ProgramInfo::Binaries)?;
- let binaries = if let ocl::core::ProgramInfoResult::Binaries(bins) = binaries_wrapped {
- bins
- } else {
- panic!()
- };
- let module = l0_create_module(device, &binaries[0]);
- let kernel_desc = l0::ze_kernel_desc_t {
- version: l0::ze_kernel_desc_version_t::ZE_KERNEL_DESC_VERSION_CURRENT,
- flags: l0::ze_kernel_flag_t::ZE_KERNEL_FLAG_NONE,
- pKernelName: "ld_st".as_ptr() as *const _,
- };
- let mut kernel: l0::ze_kernel_handle_t = ptr::null_mut();
- let mut err = unsafe { l0::zeKernelCreate(module, &kernel_desc, &mut kernel) };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- let inp_b = l0_allocate_buffer(drv, device, &input);
- let out_b = l0_allocate_buffer(drv, device, &output);
- println!("inp_b: {:?}", inp_b);
- println!("out_b: {:?}", out_b);
- let mut cmd_list = l0_create_cmd_list(device);
- println!("input: {:?}", input);
- err = unsafe {
- l0::zeCommandListAppendMemoryCopy(
- cmd_list,
- inp_b,
- input.as_ptr() as *const _,
- input.len() * mem::size_of::<T>(),
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- let pattern = 0u8;
- err = unsafe {
- l0::zeCommandListAppendMemoryFill(
- cmd_list,
- out_b,
- &pattern as *const u8 as *const _,
- 1,
- input.len() * mem::size_of::<T>(),
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe { l0::zeKernelSetGroupSize(kernel, 1, 1, 1) };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- let wg_size = l0::ze_group_count_t {
- groupCountX: 1,
- groupCountY: 1,
- groupCountZ: 1,
- };
- err = unsafe {
- l0::zeKernelSetArgumentValue(
- kernel,
- 0,
- mem::size_of::<*mut c_void>(),
- &inp_b as *const *mut _ as *const _,
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe {
- l0::zeKernelSetArgumentValue(
- kernel,
- 1,
- mem::size_of::<*mut c_void>(),
- &out_b as *const *mut _ as *const _,
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe {
- l0::zeCommandListAppendBarrier(
- cmd_list,
- ptr::null_mut(),
- 0,
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe {
- l0::zeCommandListAppendLaunchKernel(
- cmd_list,
- kernel,
- &wg_size,
- ptr::null_mut(),
- 0,
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe {
- l0::zeCommandListAppendBarrier(
- cmd_list,
- ptr::null_mut(),
- 0,
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- let mut result: Vec<T> = vec![0u8.into(); output.len()];
- err = unsafe {
- l0::zeCommandListAppendMemoryCopy(
- cmd_list,
- result.as_mut_ptr() as *mut _,
- out_b,
- result.len() * mem::size_of::<T>(),
- ptr::null_mut(),
- )
- };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe { l0::zeCommandListClose(cmd_list) };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err =
- unsafe { l0::zeCommandQueueExecuteCommandLists(queue, 1, &mut cmd_list, ptr::null_mut()) };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- err = unsafe { l0::zeCommandQueueSynchronize(queue, u32::max_value()) };
- assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
- /*
- let (plat, dev) = get_ocl_platform_device();
- let ctx = Context::builder().platform(plat).devices(dev).build()?;
- let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap();
- let byte_il = unsafe {
- slice::from_raw_parts::<u8>(
- spirv.as_ptr() as *const _,
- spirv.len() * mem::size_of::<u32>(),
- )
- };
- let src = CString::new(
- "
- __kernel void ld_st(ulong a, ulong b)
- {
- __global ulong* a_copy = (__global ulong*)a;
- __global ulong* b_copy = (__global ulong*)b;
- *b_copy = *a_copy;
- }",
- )
- .unwrap();
- //let prog = Program::with_il(byte_il, Some(&[dev]), &empty_cstr, &ctx)?;
- let prog = Program::with_source(&ctx, &[src], Some(&[dev]), &empty_cstr)?;
- let queue = Queue::new(&ctx, dev, None)?;
- let cl_device_mem_alloc_intel = get_cl_device_mem_alloc_intel(&plat)?;
- let cl_enqueue_memcpy_intel = get_cl_enqueue_memcpy_intel(&plat)?;
- let cl_enqueue_memset_intel = get_cl_enqueue_memset_intel(&plat)?;
- let cl_set_kernel_arg_mem_pointer_intel = get_cl_set_kernel_arg_mem_pointer_intel(&plat)?;
- let mut err_code = 0;
- let inp_b = cl_device_mem_alloc_intel(
- ctx.as_ptr(),
- dev.as_raw(),
- ptr::null_mut(),
- input.len() * mem::size_of::<T>(),
- mem::align_of::<T>() as u32,
- &mut err_code,
- );
- assert_eq!(err_code, 0);
- let out_b = cl_device_mem_alloc_intel(
- ctx.as_ptr(),
- dev.as_raw(),
- ptr::null_mut(),
- output.len() * mem::size_of::<T>(),
- mem::align_of::<T>() as u32,
- &mut err_code,
- );
- assert_eq!(err_code, 0);
- err_code = cl_enqueue_memcpy_intel(
- queue.as_ptr(),
- 1,
- inp_b as *mut _,
- input.as_ptr() as *const _,
- input.len() * mem::size_of::<T>(),
- 0,
- ptr::null(),
- ptr::null_mut(),
- );
- assert_eq!(err_code, 0);
- err_code = cl_enqueue_memset_intel(
- queue.as_ptr(),
- out_b as *mut _,
- 0,
- input.len() * mem::size_of::<T>(),
- 0,
- ptr::null(),
- ptr::null_mut(),
- );
- assert_eq!(err_code, 0);
- let kernel = ocl::core::create_kernel(prog.as_core(), name)?;
- err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 0, inp_b);
- assert_eq!(err_code, 0);
- err_code = cl_set_kernel_arg_mem_pointer_intel(kernel.as_ptr(), 1, out_b);
- assert_eq!(err_code, 0);
- unsafe {
- ocl::core::enqueue_kernel::<(), ()>(
- queue.as_core(),
- &kernel,
- 1,
- None,
- &[1, 0, 0],
- None,
- None,
- None,
- )
- }?;
- let mut result: Vec<T> = vec![0u8.into(); output.len()];
- err_code = cl_enqueue_memcpy_intel(
- queue.as_ptr(),
- 1,
- result.as_mut_ptr() as *mut _,
- inp_b,
- result.len() * mem::size_of::<T>(),
- 0,
- ptr::null(),
- ptr::null_mut(),
- );
- assert_eq!(err_code, 0);
- queue.finish()?;
- */
+) -> ze::Result<Vec<T>> {
+ ze::init()?;
+ let mut result = vec![0u8.into(); output.len()];
+ let mut drivers = ze::Driver::get()?;
+ let drv = drivers.drain(0..1).next().unwrap();
+ let mut devices = drv.devices()?;
+ let dev = devices.drain(0..1).next().unwrap();
+ let queue = ze::CommandQueue::new(&dev)?;
+ let native_bins = get_program_native().unwrap();
+ let module = ze::Module::new_native(&dev, &native_bins[0])?;
+ let kernel = ze::Kernel::new(&module, CStr::from_bytes_with_nul(b"ld_st\0").unwrap())?;
+ let mut inp_b = ze::DeviceBuffer::<T>::new(&drv, &dev, input.len())?;
+ let mut out_b = ze::DeviceBuffer::<T>::new(&drv, &dev, output.len())?;
+ let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
+ let event_pool = ze::EventPool::new(&drv, 3, Some(&[&dev]))?;
+ let ev0 = ze::Event::new(&event_pool, 0)?;
+ let ev1 = ze::Event::new(&event_pool, 1)?;
+ let mut cmd_list = ze::CommandList::new(&dev)?;
+ let out_b_ptr: ze::BufferPtrMut<T> = (&mut out_b).into();
+ cmd_list.append_memory_copy(inp_b_ptr_mut, input, None, Some(&ev0))?;
+ cmd_list.append_memory_fill(out_b_ptr, 0u8.into(), Some(&ev1))?;
+ kernel.set_group_size(1, 1, 1)?;
+ kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
+ kernel.set_arg_buffer(1, out_b_ptr)?;
+ cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], None, &[&ev0, &ev1])?;
+ cmd_list.append_memory_copy(result.as_mut_slice(), inp_b_ptr_mut, None, Some(&ev0))?;
+ queue.execute(cmd_list)?;
Ok(result)
}
@@ -491,3 +287,28 @@ fn l0_create_cmd_list(dev: l0::ze_device_handle_t) -> l0::ze_command_list_handle
assert_eq!(err, l0::ze_result_t::ZE_RESULT_SUCCESS);
result
}
+
+fn get_program_native() -> ocl::Result<Vec<Vec<u8>>> {
+ let (ocl_plat, ocl_dev) = get_ocl_platform_device();
+ let ocl_ctx = Context::builder()
+ .platform(ocl_plat)
+ .devices(ocl_dev)
+ .build()?;
+ let empty_cstr = CString::new("-cl-intel-greater-than-4GB-buffer-required").unwrap();
+ let src = CString::new(
+ "
+ __kernel void ld_st(ulong a, ulong b)
+ {
+ __global ulong* a_copy = (__global ulong*)a;
+ __global ulong* b_copy = (__global ulong*)b;
+ *b_copy = *a_copy;
+ }",
+ )?;
+ let prog = Program::with_source(&ocl_ctx, &[src], None, &empty_cstr)?;
+ let binaries_wrapped = prog.info(ocl::core::ProgramInfo::Binaries)?;
+ if let ocl::core::ProgramInfoResult::Binaries(bins) = binaries_wrapped {
+ Ok(bins)
+ } else {
+ panic!()
+ }
+}