aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-05-27 02:05:17 +0200
committerAndrzej Janik <[email protected]>2021-05-27 02:05:17 +0200
commite40785aa7491de16c65de7aa599105102ffa7355 (patch)
tree87b4b16dbf6318aae8456a04ab6af574d2238ddb
parent58a7fe53c6feaf96156c455b7c3b1def9d7e6d56 (diff)
downloadZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.tar.gz
ZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.zip
Refactor L0 bindings
-rw-r--r--level_zero/src/ze.rs804
-rw-r--r--ptx/src/test/spirv_run/mod.rs63
-rw-r--r--zluda/src/impl/context.rs6
-rw-r--r--zluda/src/impl/device.rs38
-rw-r--r--zluda/src/impl/function.rs4
-rw-r--r--zluda/src/impl/memory.rs18
-rw-r--r--zluda/src/impl/module.rs20
-rw-r--r--zluda/src/impl/stream.rs8
-rw-r--r--zluda_ml/src/impl.rs3
9 files changed, 561 insertions, 403 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index d2b1115..88adfe6 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -1,11 +1,37 @@
+use sys::zeFenceDestroy;
+
use crate::sys;
use std::{
ffi::{c_void, CStr, CString},
fmt::Debug,
marker::PhantomData,
- mem, ptr,
+ mem,
+ ptr::{self, NonNull},
};
+/*
+ This module is not a user-friendly, safe binding. The problem is tracking
+ object lifetimes. E.g. kernel object cannot outlive module object.
+ While Rust is relatively good at it, it's tricky to translate it to a safe
+ API in a way that we can mix and match them, but here's I'd sketch it:
+ - There's no &mut references: all API operations copy data in and out
+ - All baseline objects are Send, but not Sync
+ - There are some problems with using "naked" Rc and Arc:
+ - We should not allow users to create Rc by themselves without including
+ parent pointer
+ - We should not allow DerefMut in Mutex and moving out of it
+ - Objects are wrapped in Rc<ZeCell<_>> and Arc<ZeMutex<_>>, parent
+ pointer is part of ZeCell/ZeMutex:
+ - Then e.g. zeKernelCreate is mapped three times:
+ - unsafe Module(&self) -> Kernel
+ - Module(&Rc<ZeCell<Module>>) -> Rc<ZeCell<Kernel>>
+ - Module(&Arc<ZeMutex<Module>>) -> Arc<KernelMutex>
+ - You create ZeCell<Module> by moving Module and Rc<ZeCell<Context>
+ - Pro: Rc and Arc are allowed to be self receivers
+ - Open question: should some operations take the parent mutex? If so, should
+ it be done recursively?
+*/
+
macro_rules! check {
($expr:expr) => {
#[allow(unused_unsafe)]
@@ -39,102 +65,155 @@ pub fn init() -> Result<()> {
}
}
+// Mutability: no (list of allocations is under a mutex)
+// Lifetime: 'static
#[repr(transparent)]
-pub struct Driver(sys::ze_driver_handle_t);
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct Driver(NonNull<sys::_ze_driver_handle_t>);
unsafe impl Send for Driver {}
unsafe impl Sync for Driver {}
impl Driver {
- pub unsafe fn as_ffi(&self) -> sys::ze_driver_handle_t {
- self.0
+ pub unsafe fn as_ffi(self) -> sys::ze_driver_handle_t {
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_driver_handle_t) -> Self {
- Self(x)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x))
}
pub fn get() -> Result<Vec<Self>> {
let mut len = 0;
let mut temp = ptr::null_mut();
check!(sys::zeDriverGet(&mut len, &mut temp));
- let mut result = (0..len)
- .map(|_| Driver(ptr::null_mut()))
- .collect::<Vec<_>>();
+ let mut result = Vec::with_capacity(len as usize);
check!(sys::zeDriverGet(&mut len, result.as_mut_ptr() as *mut _));
+ unsafe {
+ result.set_len(len as usize);
+ }
Ok(result)
}
- pub fn devices(&self) -> Result<Vec<Device>> {
+ pub fn devices(self) -> Result<Vec<Device>> {
let mut len = 0;
let mut temp = ptr::null_mut();
- check!(sys::zeDeviceGet(self.0, &mut len, &mut temp));
- let mut result = (0..len)
- .map(|_| Device(ptr::null_mut()))
- .collect::<Vec<_>>();
+ check!(sys::zeDeviceGet(self.as_ffi(), &mut len, &mut temp));
+ let mut result = Vec::with_capacity(len as usize);
check!(sys::zeDeviceGet(
- self.0,
+ self.as_ffi(),
&mut len,
result.as_mut_ptr() as *mut _
));
- if (len as usize) < result.len() {
- result.truncate(len as usize);
+ unsafe {
+ result.set_len(len as usize);
}
Ok(result)
}
- pub fn get_properties(&self) -> Result<sys::ze_driver_properties_t> {
- let mut result = unsafe { mem::zeroed::<sys::ze_driver_properties_t>() };
- check!(sys::zeDriverGetProperties(self.0, &mut result));
- Ok(result)
+ pub fn get_properties(self, props: &mut sys::ze_driver_properties_t) -> Result<()> {
+ check!(sys::zeDriverGetProperties(self.as_ffi(), props));
+ Ok(())
}
}
+// Mutability: no (list of peer allocations under a mutex)
+// Lifetime: 'static
#[repr(transparent)]
-pub struct Device(sys::ze_device_handle_t);
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct Device(NonNull<sys::_ze_device_handle_t>);
+
+unsafe impl Send for Device {}
+unsafe impl Sync for Device {}
impl Device {
- pub unsafe fn as_ffi(&self) -> sys::ze_device_handle_t {
- self.0
+ pub unsafe fn as_ffi(self) -> sys::ze_device_handle_t {
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_device_handle_t) -> Self {
- Self(x)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x))
}
- pub fn get_properties(&self) -> Result<Box<sys::ze_device_properties_t>> {
- let mut props = Box::new(unsafe { mem::zeroed::<sys::ze_device_properties_t>() });
- check! { sys::zeDeviceGetProperties(self.0, props.as_mut()) };
- Ok(props)
+ pub fn get_properties(self, props: &mut sys::ze_device_properties_t) -> Result<()> {
+ check! { sys::zeDeviceGetProperties(self.as_ffi(), props) };
+ Ok(())
}
- pub fn get_image_properties(&self) -> Result<Box<sys::ze_device_image_properties_t>> {
- let mut props = Box::new(unsafe { mem::zeroed::<sys::ze_device_image_properties_t>() });
- check! { sys::zeDeviceGetImageProperties(self.0, props.as_mut()) };
- Ok(props)
+ pub fn get_image_properties(self, props: &mut sys::ze_device_image_properties_t) -> Result<()> {
+ check! { sys::zeDeviceGetImageProperties(self.as_ffi(), props) };
+ Ok(())
}
- pub fn get_memory_properties(&self) -> Result<Vec<sys::ze_device_memory_properties_t>> {
+ pub fn get_memory_properties(self) -> Result<Vec<sys::ze_device_memory_properties_t>> {
let mut count = 0u32;
- check! { sys::zeDeviceGetMemoryProperties(self.0, &mut count, ptr::null_mut()) };
+ check! { sys::zeDeviceGetMemoryProperties(self.as_ffi(), &mut count, ptr::null_mut()) };
if count == 0 {
return Ok(Vec::new());
}
let mut props =
vec![unsafe { mem::zeroed::<sys::ze_device_memory_properties_t>() }; count as usize];
- check! { sys::zeDeviceGetMemoryProperties(self.0, &mut count, props.as_mut_ptr()) };
+ check! { sys::zeDeviceGetMemoryProperties(self.as_ffi(), &mut count, props.as_mut_ptr()) };
Ok(props)
}
- pub fn get_compute_properties(&self) -> Result<Box<sys::ze_device_compute_properties_t>> {
- let mut props = Box::new(unsafe { mem::zeroed::<sys::ze_device_compute_properties_t>() });
- check! { sys::zeDeviceGetComputeProperties(self.0, props.as_mut()) };
- Ok(props)
+ pub fn get_compute_properties(
+ self,
+ props: &mut sys::ze_device_compute_properties_t,
+ ) -> Result<()> {
+ check! { sys::zeDeviceGetComputeProperties(self.as_ffi(), props) };
+ Ok(())
+ }
+}
+
+// Mutability: no
+#[repr(transparent)]
+pub struct Context(NonNull<sys::_ze_context_handle_t>);
+
+unsafe impl Send for Context {}
+unsafe impl Sync for Context {}
+
+impl Context {
+ pub unsafe fn as_ffi(&self) -> sys::ze_context_handle_t {
+ self.0.as_ptr()
+ }
+ pub unsafe fn from_ffi(x: sys::ze_context_handle_t) -> Self {
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x))
}
- pub unsafe fn mem_alloc_device(
- &mut self,
- ctx: &mut Context,
+ pub fn new(drv: Driver, devices: Option<&[Device]>) -> Result<Self> {
+ let ctx_desc = sys::ze_context_desc_t {
+ stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_CONTEXT_DESC,
+ pNext: ptr::null(),
+ flags: sys::ze_context_flags_t(0),
+ };
+ let mut result = ptr::null_mut();
+ let (dev_ptr, dev_len) = match devices {
+ None => (ptr::null(), 0),
+ Some(devs) => (devs.as_ptr(), devs.len()),
+ };
+ check!(sys::zeContextCreateEx(
+ drv.as_ffi(),
+ &ctx_desc,
+ dev_len as u32,
+ dev_ptr as *mut _,
+ &mut result
+ ));
+ Ok(unsafe { Self::from_ffi(result) })
+ }
+
+ pub fn mem_alloc_device(
+ &self,
size: usize,
alignment: usize,
+ device: Device,
) -> Result<*mut c_void> {
let descr = sys::ze_device_mem_alloc_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
@@ -143,47 +222,24 @@ impl Device {
ordinal: 0,
};
let mut result = ptr::null_mut();
- // TODO: check current context for the device
check! {
sys::zeMemAllocDevice(
- ctx.0,
+ self.as_ffi(),
&descr,
size,
alignment,
- self.0,
+ device.as_ffi(),
&mut result,
)
};
Ok(result)
}
-}
-#[repr(transparent)]
-pub struct Context(sys::ze_context_handle_t);
-
-impl Context {
- pub unsafe fn as_ffi(&self) -> sys::ze_context_handle_t {
- self.0
- }
- pub unsafe fn from_ffi(x: sys::ze_context_handle_t) -> Self {
- Self(x)
- }
-
- pub fn new(drv: &Driver) -> Result<Self> {
- let ctx_desc = sys::ze_context_desc_t {
- stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_CONTEXT_DESC,
- pNext: ptr::null(),
- flags: sys::ze_context_flags_t(0),
- };
- let mut result = ptr::null_mut();
- check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
- Ok(Context(result))
- }
-
- pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
+ // This operation is safe because Level Zero impl tracks allocations
+ pub fn mem_free(&self, ptr: *mut c_void) -> Result<()> {
check! {
sys::zeMemFree(
- self.0,
+ self.as_ffi(),
ptr,
)
};
@@ -194,22 +250,32 @@ impl Context {
impl Drop for Context {
#[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeContextDestroy(self.0) };
+ check_panic! { sys::zeContextDestroy(self.as_ffi()) };
}
}
+// Mutability: yes (residency container and others)
+// Lifetime parent: Context
#[repr(transparent)]
-pub struct CommandQueue(sys::ze_command_queue_handle_t);
+pub struct CommandQueue<'a>(
+ NonNull<sys::_ze_command_queue_handle_t>,
+ PhantomData<&'a ()>,
+);
+
+unsafe impl<'a> Send for CommandQueue<'a> {}
-impl CommandQueue {
+impl<'a> CommandQueue<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_command_queue_handle_t {
- self.0
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_command_queue_handle_t) -> Self {
- Self(x)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
}
- pub fn new(ctx: &mut Context, d: &Device) -> Result<Self> {
+ pub fn new(ctx: &'a Context, d: Device) -> Result<Self> {
let que_desc = sys::ze_command_queue_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
pNext: ptr::null(),
@@ -221,48 +287,138 @@ impl CommandQueue {
};
let mut result = ptr::null_mut();
check!(sys::zeCommandQueueCreate(
- ctx.0,
- d.0,
+ ctx.as_ffi(),
+ d.as_ffi(),
&que_desc,
&mut result
));
- Ok(CommandQueue(result))
+ Ok(unsafe { Self::from_ffi(result) })
+ }
+
+ pub fn execute_and_synchronize<'cmd_list>(
+ &'a self,
+ cmd: CommandList<'cmd_list>,
+ ) -> Result<FenceGuard<'cmd_list>>
+ where
+ 'a: 'cmd_list,
+ {
+ let fence_guard = FenceGuard::new(self, cmd)?;
+ unsafe { self.execute(&fence_guard.1, Some(&fence_guard.0))? };
+ Ok(fence_guard)
}
- pub fn execute<'a>(&'a self, cmd: CommandList) -> Result<FenceGuard<'a>> {
- check!(sys::zeCommandListClose(cmd.0));
- let result = FenceGuard::new(self, cmd.0)?;
- let mut raw_cmd = cmd.0;
- mem::forget(cmd);
+ pub unsafe fn execute<'cmd_list, 'fence>(
+ &self,
+ cmd: &CommandList<'cmd_list>,
+ fence: Option<&Fence<'fence>>,
+ ) -> Result<()>
+ where
+ 'cmd_list: 'fence,
+ 'a: 'cmd_list,
+ {
+ let fence_ptr = fence.map_or(ptr::null_mut(), |f| f.as_ffi());
check!(sys::zeCommandQueueExecuteCommandLists(
- self.0,
+ self.as_ffi(),
1,
- &mut raw_cmd,
- result.0
+ &mut cmd.as_ffi(),
+ fence_ptr
));
- Ok(result)
+ Ok(())
}
}
-impl Drop for CommandQueue {
+impl<'a> Drop for CommandQueue<'a> {
#[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeCommandQueueDestroy(self.0) };
+ check_panic! { sys::zeCommandQueueDestroy(self.as_ffi()) };
+ }
+}
+
+pub struct FenceGuard<'a>(Fence<'a>, CommandList<'a>);
+
+impl<'a> FenceGuard<'a> {
+ fn new(q: &'a CommandQueue, cmd_list: CommandList<'a>) -> Result<Self> {
+ Ok(FenceGuard(Fence::new(q)?, cmd_list))
+ }
+}
+
+impl<'a> Drop for FenceGuard<'a> {
+ #[allow(unused_must_use)]
+ fn drop(&mut self) {
+ if let Err(e) = self.0.host_synchronize() {
+ panic!(e)
+ }
+ }
+}
+
+// Mutability: yes (reset)
+// Lifetime parent: queue
+#[repr(transparent)]
+pub struct Fence<'a>(NonNull<sys::_ze_fence_handle_t>, PhantomData<&'a ()>);
+
+unsafe impl<'a> Send for Fence<'a> {}
+
+impl<'a> Fence<'a> {
+ pub unsafe fn as_ffi(&self) -> sys::ze_fence_handle_t {
+ self.0.as_ptr()
+ }
+ pub unsafe fn from_ffi(x: sys::ze_fence_handle_t) -> Self {
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
+ }
+
+ pub fn new(queue: &'a CommandQueue) -> Result<Self> {
+ let desc = sys::_ze_fence_desc_t {
+ stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_FENCE_DESC,
+ pNext: ptr::null(),
+ flags: sys::ze_fence_flags_t(0),
+ };
+ let mut result = ptr::null_mut();
+ check!(sys::zeFenceCreate(queue.as_ffi(), &desc, &mut result));
+ Ok(unsafe { Self::from_ffi(result) })
+ }
+
+ pub fn host_synchronize(&self) -> Result<()> {
+ check!(sys::zeFenceHostSynchronize(self.as_ffi(), u64::max_value()));
+ Ok(())
+ }
+}
+
+impl<'a> Drop for Fence<'a> {
+ fn drop(&mut self) {
+ check_panic! { zeFenceDestroy(self.as_ffi()) };
}
}
-pub struct Module(sys::ze_module_handle_t);
+// Mutability: yes (building, linking)
+// Lifetime parent: Context
+#[repr(transparent)]
+pub struct Module<'a>(NonNull<sys::_ze_module_handle_t>, PhantomData<&'a ()>);
+
+unsafe impl<'a> Send for Module<'a> {}
+
+impl<'a> Module<'a> {
+ pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t {
+ self.0.as_ptr()
+ }
+ pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self {
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
+ }
-impl Module {
// HACK ALERT
// We use OpenCL for now to do SPIR-V linking, because Level0
// does not allow linking. Don't let presence of zeModuleDynamicLink fool
// you, it's not currently possible to create non-compiled modules.
// zeModuleCreate always compiles (builds and links).
- pub fn build_link_spirv<'a>(
- ctx: &mut Context,
- d: &Device,
- binaries: &[&'a [u8]],
+ pub fn build_link_spirv<'buffers>(
+ ctx: &'a Context,
+ d: Device,
+ binaries: &[&'buffers [u8]],
opts: Option<&CStr>,
) -> (Result<Self>, Option<BuildLog>) {
let ocl_program = match Self::build_link_spirv_impl(binaries, opts) {
@@ -283,8 +439,8 @@ impl Module {
}
}
- fn build_link_spirv_impl<'a>(
- binaries: &[&'a [u8]],
+ fn build_link_spirv_impl<'buffers>(
+ binaries: &[&'buffers [u8]],
opts: Option<&CStr>,
) -> ocl_core::Result<ocl_core::Program> {
let platforms = ocl_core::get_platform_ids()?;
@@ -348,8 +504,8 @@ impl Module {
}
pub fn build_spirv(
- ctx: &mut Context,
- d: &Device,
+ ctx: &'a Context,
+ d: Device,
bin: &[u8],
opts: Option<&CStr>,
) -> Result<Self> {
@@ -357,8 +513,8 @@ impl Module {
}
pub fn build_spirv_logged(
- ctx: &mut Context,
- d: &Device,
+ ctx: &'a Context,
+ d: Device,
bin: &[u8],
opts: Option<&CStr>,
) -> (Result<Self>, BuildLog) {
@@ -366,17 +522,17 @@ impl Module {
}
pub fn build_native_logged(
- ctx: &mut Context,
- d: &Device,
+ ctx: &'a Context,
+ d: Device,
bin: &[u8],
) -> (Result<Self>, BuildLog) {
Module::new_logged(ctx, false, d, bin, None)
}
fn new(
- ctx: &mut Context,
+ ctx: &'a Context,
spirv: bool,
- d: &Device,
+ d: Device,
bin: &[u8],
opts: Option<&CStr>,
) -> Result<Self> {
@@ -394,18 +550,22 @@ impl Module {
pConstants: ptr::null(),
};
let mut result: sys::ze_module_handle_t = ptr::null_mut();
- let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, ptr::null_mut()) };
- if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
- Result::Err(err)
- } else {
- Ok(Module(result))
- }
+ check! {
+ sys::zeModuleCreate(
+ ctx.as_ffi(),
+ d.as_ffi(),
+ &desc,
+ &mut result,
+ ptr::null_mut(),
+ )
+ };
+ Ok(unsafe { Self::from_ffi(result) })
}
fn new_logged(
- ctx: &mut Context,
+ ctx: &'a Context,
spirv: bool,
- d: &Device,
+ d: Device,
bin: &[u8],
opts: Option<&CStr>,
) -> (Result<Self>, BuildLog) {
@@ -424,74 +584,83 @@ impl Module {
};
let mut result: sys::ze_module_handle_t = ptr::null_mut();
let mut log_handle = ptr::null_mut();
- let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) };
- let log = BuildLog(log_handle);
- if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
+ let err = unsafe {
+ sys::zeModuleCreate(
+ ctx.as_ffi(),
+ d.as_ffi(),
+ &desc,
+ &mut result,
+ &mut log_handle,
+ )
+ };
+ let log = unsafe { BuildLog::from_ffi(log_handle) };
+ if err != sys::ze_result_t::ZE_RESULT_SUCCESS {
(Result::Err(err), log)
} else {
- (Ok(Module(result)), log)
+ (Ok(unsafe { Self::from_ffi(result) }), log)
}
}
}
-impl Drop for Module {
+impl<'a> Drop for Module<'a> {
#[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeModuleDestroy(self.0) };
+ check_panic! { sys::zeModuleDestroy(self.as_ffi()) };
}
}
-pub struct BuildLog(sys::ze_module_build_log_handle_t);
+// Mutability: none
+// Lifetime parent: none, but need to destroy
+pub struct BuildLog(NonNull<sys::_ze_module_build_log_handle_t>);
+
+unsafe impl Sync for BuildLog {}
+unsafe impl Send for BuildLog {}
impl BuildLog {
pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t {
- self.0
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self {
- Self(x)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x))
}
- pub fn get_cstring(&self) -> Result<CString> {
+ pub fn to_cstring(&self) -> Result<CString> {
let mut size = 0;
- check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) };
+ check! { sys::zeModuleBuildLogGetString(self.as_ffi(), &mut size, ptr::null_mut()) };
let mut str_vec = vec![0u8; size];
- check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) };
- str_vec.pop();
- Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?)
+ check! { sys::zeModuleBuildLogGetString(self.as_ffi(), &mut size, str_vec.as_mut_ptr() as *mut i8) };
+ str_vec.push(0);
+ Ok(unsafe { CString::from_vec_unchecked(str_vec) })
}
}
impl Drop for BuildLog {
fn drop(&mut self) {
- check_panic!(sys::zeModuleBuildLogDestroy(self.0));
+ check_panic!(sys::zeModuleBuildLogDestroy(self.as_ffi()));
}
}
-pub trait SafeRepr {}
-impl SafeRepr for u8 {}
-impl SafeRepr for i8 {}
-impl SafeRepr for u16 {}
-impl SafeRepr for i16 {}
-impl SafeRepr for u32 {}
-impl SafeRepr for i32 {}
-impl SafeRepr for u64 {}
-impl SafeRepr for i64 {}
-impl SafeRepr for f32 {}
-impl SafeRepr for f64 {}
-
-pub struct DeviceBuffer<T: SafeRepr> {
+// Mutability: none
+// Lifetime parent: Context
+pub struct DeviceBuffer<'a, T: Copy> {
ptr: *mut c_void,
ctx: sys::ze_context_handle_t,
len: usize,
- marker: PhantomData<T>,
+ marker: PhantomData<&'a T>,
}
-impl<T: SafeRepr> DeviceBuffer<T> {
- pub unsafe fn as_ffi(&self) -> *mut c_void {
- self.ptr
+unsafe impl<'a, T: Copy> Sync for DeviceBuffer<'a, T> {}
+unsafe impl<'a, T: Copy> Send for DeviceBuffer<'a, T> {}
+
+impl<'a, T: Copy> DeviceBuffer<'a, T> {
+ pub unsafe fn as_ffi(&self) -> (sys::ze_context_handle_t, *mut c_void, usize) {
+ (self.ctx, self.ptr, self.len)
}
pub unsafe fn from_ffi(ctx: sys::ze_context_handle_t, ptr: *mut c_void, len: usize) -> Self {
- let marker = PhantomData::<T>;
+ let marker = PhantomData::<&'a T>;
Self {
ptr,
ctx,
@@ -500,7 +669,7 @@ impl<T: SafeRepr> DeviceBuffer<T> {
}
}
- pub fn new(ctx: &mut Context, dev: &Device, len: usize) -> Result<Self> {
+ pub fn new(ctx: &'a Context, dev: Device, len: usize) -> Result<Self> {
let desc = sys::_ze_device_mem_alloc_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
pNext: ptr::null(),
@@ -509,39 +678,49 @@ impl<T: SafeRepr> DeviceBuffer<T> {
};
let mut result = ptr::null_mut();
check!(sys::zeMemAllocDevice(
- ctx.0,
+ ctx.as_ffi(),
&desc,
len * mem::size_of::<T>(),
mem::align_of::<T>(),
- dev.0,
+ dev.as_ffi(),
&mut result
));
- Ok(unsafe { Self::from_ffi(ctx.0, result, len) })
+ Ok(unsafe { Self::from_ffi(ctx.as_ffi(), result, len) })
}
pub fn len(&self) -> usize {
self.len
}
+
+ pub fn data(&self) -> *mut c_void {
+ self.ptr
+ }
}
-impl<T: SafeRepr> Drop for DeviceBuffer<T> {
- #[allow(unused_must_use)]
+impl<'a, T: Copy> Drop for DeviceBuffer<'a, T> {
fn drop(&mut self) {
check_panic! { sys::zeMemFree(self.ctx, self.ptr) };
}
}
-pub struct CommandList<'a>(sys::ze_command_list_handle_t, PhantomData<&'a ()>);
+// Mutability: yes (appends)
+// Lifetime parent: Context
+pub struct CommandList<'a>(NonNull<sys::_ze_command_list_handle_t>, PhantomData<&'a ()>);
+
+unsafe impl<'a> Send for CommandList<'a> {}
impl<'a> CommandList<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_command_list_handle_t {
- self.0
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_command_list_handle_t) -> Self {
- Self(x, PhantomData)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
}
- pub fn new(ctx: &mut Context, dev: &Device) -> Result<Self> {
+ pub fn new(ctx: &'a Context, dev: Device) -> Result<Self> {
let desc = sys::ze_command_list_desc_t {
stype: sys::_ze_structure_type_t::ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
commandQueueGroupOrdinal: 0,
@@ -549,40 +728,46 @@ impl<'a> CommandList<'a> {
flags: sys::ze_command_list_flags_t(0),
};
let mut result: sys::ze_command_list_handle_t = ptr::null_mut();
- check!(sys::zeCommandListCreate(ctx.0, dev.0, &desc, &mut result));
- Ok(Self(result, PhantomData))
+ check!(sys::zeCommandListCreate(
+ ctx.as_ffi(),
+ dev.as_ffi(),
+ &desc,
+ &mut result
+ ));
+ Ok(unsafe { Self::from_ffi(result) })
}
- pub fn append_memory_copy<
- T: 'a,
- Dst: Into<BufferPtrMut<'a, T>>,
- Src: Into<BufferPtr<'a, T>>,
- >(
- &mut self,
+ pub fn append_memory_copy<'event, T: 'a, Dst: Into<Slice<'a, T>>, Src: Into<Slice<'a, T>>>(
+ &'a self,
dst: Dst,
src: Src,
- signal: Option<&mut Event<'a>>,
- wait: &mut [Event<'a>],
- ) -> Result<()> {
+ signal: Option<&Event<'event>>,
+ wait: &[Event<'event>],
+ ) -> Result<()>
+ where
+ 'event: 'a,
+ {
let dst = dst.into();
let src = src.into();
let elements = std::cmp::min(dst.len(), src.len());
let length = elements * mem::size_of::<T>();
- unsafe { self.append_memory_copy_unsafe(dst.get(), src.get(), length, signal, wait) }
+ unsafe {
+ self.append_memory_copy_unsafe(dst.as_mut_ptr(), src.as_ptr(), length, signal, wait)
+ }
}
pub unsafe fn append_memory_copy_unsafe(
- &mut self,
+ &self,
dst: *mut c_void,
src: *const c_void,
length: usize,
- signal: Option<&mut Event<'a>>,
- wait: &mut [Event<'a>],
+ signal: Option<&Event>,
+ wait: &[Event],
) -> Result<()> {
- let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = Event::raw_slice(wait);
check!(sys::zeCommandListAppendMemoryCopy(
- self.0,
+ self.as_ffi(),
dst,
src,
length,
@@ -593,20 +778,26 @@ impl<'a> CommandList<'a> {
Ok(())
}
- pub fn append_memory_fill<T>(
- &mut self,
- dst: BufferPtrMut<'a, T>,
+ pub fn append_memory_fill<'event, T: 'a, Dst: Into<Slice<'a, T>>>(
+ &'a self,
+ dst: Dst,
pattern: u8,
- signal: Option<&mut Event<'a>>,
- wait: &mut [Event<'a>],
- ) -> Result<()> {
+ signal: Option<&Event<'event>>,
+ wait: &[Event<'event>],
+ ) -> Result<()>
+ where
+ 'event: 'a,
+ {
+ let dst = dst.into();
let raw_pattern = &pattern as *const u8 as *const _;
- let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let signal_event = signal
+ .map(|e| unsafe { e.as_ffi() })
+ .unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
let byte_len = dst.len() * mem::size_of::<T>();
check!(sys::zeCommandListAppendMemoryFill(
- self.0,
- dst.get(),
+ self.as_ffi(),
+ dst.as_mut_ptr(),
raw_pattern,
mem::size_of::<u8>(),
byte_len,
@@ -618,17 +809,17 @@ impl<'a> CommandList<'a> {
}
pub unsafe fn append_memory_fill_unsafe<T: Copy + Sized>(
- &mut self,
+ &self,
dst: *mut c_void,
pattern: &T,
byte_size: usize,
- signal: Option<&mut Event<'a>>,
- wait: &mut [Event<'a>],
+ signal: Option<&Event>,
+ wait: &[Event],
) -> Result<()> {
- let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let signal_event = signal.map(|e| e.as_ffi()).unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = Event::raw_slice(wait);
check!(sys::zeCommandListAppendMemoryFill(
- self.0,
+ self.as_ffi(),
dst,
pattern as *const T as *const _,
mem::size_of::<T>(),
@@ -640,23 +831,29 @@ impl<'a> CommandList<'a> {
Ok(())
}
- pub fn append_launch_kernel(
- &mut self,
- kernel: &'a Kernel,
+ pub fn append_launch_kernel<'event, 'kernel>(
+ &'a self,
+ kernel: &'kernel Kernel,
group_count: &[u32; 3],
- signal: Option<&mut Event<'a>>,
- wait: &mut [Event<'a>],
- ) -> Result<()> {
+ signal: Option<&Event<'event>>,
+ wait: &[Event<'event>],
+ ) -> Result<()>
+ where
+ 'event: 'a,
+ 'kernel: 'a,
+ {
let gr_count = sys::ze_group_count_t {
groupCountX: group_count[0],
groupCountY: group_count[1],
groupCountZ: group_count[2],
};
- let signal_event = signal.map(|e| e.0).unwrap_or(ptr::null_mut());
+ let signal_event = signal
+ .map(|e| unsafe { e.as_ffi() })
+ .unwrap_or(ptr::null_mut());
let (wait_len, wait_ptr) = unsafe { Event::raw_slice(wait) };
check!(sys::zeCommandListAppendLaunchKernel(
- self.0,
- kernel.0,
+ self.as_ffi(),
+ kernel.as_ffi(),
&gr_count,
signal_event,
wait_len,
@@ -664,176 +861,129 @@ impl<'a> CommandList<'a> {
));
Ok(())
}
-}
-
-impl<'a> Drop for CommandList<'a> {
- #[allow(unused_must_use)]
- fn drop(&mut self) {
- check_panic! { sys::zeCommandListDestroy(self.0) };
- }
-}
-
-pub struct FenceGuard<'a>(
- sys::ze_fence_handle_t,
- sys::ze_command_list_handle_t,
- PhantomData<&'a ()>,
-);
-impl<'a> FenceGuard<'a> {
- fn new(q: &'a CommandQueue, cmd_list: sys::ze_command_list_handle_t) -> Result<Self> {
- let desc = sys::_ze_fence_desc_t {
- stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_FENCE_DESC,
- pNext: ptr::null(),
- flags: sys::ze_fence_flags_t(0),
- };
- let mut result = ptr::null_mut();
- check!(sys::zeFenceCreate(q.0, &desc, &mut result));
- Ok(FenceGuard(result, cmd_list, PhantomData))
+ pub fn close(&self) -> Result<()> {
+ check!(sys::zeCommandListClose(self.as_ffi()));
+ Ok(())
}
}
-impl<'a> Drop for FenceGuard<'a> {
+impl<'a> Drop for CommandList<'a> {
#[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeFenceHostSynchronize(self.0, u64::max_value()) };
- check_panic! { sys::zeFenceDestroy(self.0) };
- check_panic! { sys::zeCommandListDestroy(self.1) };
+ check_panic! { sys::zeCommandListDestroy(self.as_ffi()) };
}
}
#[derive(Copy, Clone)]
-pub struct BufferPtr<'a, T> {
- ptr: *const c_void,
+pub struct Slice<'a, T> {
+ ptr: *mut c_void,
+ len: usize,
marker: PhantomData<&'a T>,
- elems: usize,
}
-impl<'a, T> BufferPtr<'a, T> {
- pub unsafe fn get(self) -> *const c_void {
- return self.ptr;
- }
+unsafe impl<'a, T> Send for Slice<'a, T> {}
+unsafe impl<'a, T> Sync for Slice<'a, T> {}
- pub fn len(&self) -> usize {
- self.elems
- }
-}
-
-impl<'a, T> From<&'a [T]> for BufferPtr<'a, T> {
- fn from(s: &'a [T]) -> Self {
- BufferPtr {
- ptr: s.as_ptr() as *const _,
+impl<'a, T> Slice<'a, T> {
+ pub unsafe fn new(ptr: *mut c_void, len: usize) -> Self {
+ Self {
+ ptr,
+ len,
marker: PhantomData,
- elems: s.len(),
}
}
-}
-impl<'a, T: SafeRepr> From<&'a DeviceBuffer<T>> for BufferPtr<'a, T> {
- fn from(b: &'a DeviceBuffer<T>) -> Self {
- BufferPtr {
- ptr: b.ptr as *const _,
- marker: PhantomData,
- elems: b.len(),
- }
+ pub fn as_ptr(&self) -> *const c_void {
+ self.ptr
}
-}
-
-#[derive(Copy, Clone)]
-pub struct BufferPtrMut<'a, T> {
- ptr: *mut c_void,
- marker: PhantomData<&'a mut T>,
- elems: usize,
-}
-impl<'a, T> BufferPtrMut<'a, T> {
- pub unsafe fn get(self) -> *mut c_void {
- return self.ptr;
+ pub fn as_mut_ptr(&self) -> *mut c_void {
+ self.ptr
}
pub fn len(&self) -> usize {
- self.elems
- }
-}
-
-impl<'a, T> From<&'a mut [T]> for BufferPtrMut<'a, T> {
- fn from(s: &'a mut [T]) -> Self {
- BufferPtrMut {
- ptr: s.as_mut_ptr() as *mut _,
- marker: PhantomData,
- elems: s.len(),
- }
+ self.len
}
}
-impl<'a, T: SafeRepr> From<&'a mut DeviceBuffer<T>> for BufferPtrMut<'a, T> {
- fn from(b: &'a mut DeviceBuffer<T>) -> Self {
- BufferPtrMut {
- ptr: b.ptr as *mut _,
+impl<'a, T> From<&'a [T]> for Slice<'a, T> {
+ fn from(s: &'a [T]) -> Self {
+ Slice {
+ ptr: s.as_ptr() as *mut _,
+ len: s.len(),
marker: PhantomData,
- elems: b.len(),
}
}
}
-impl<'a, T: SafeRepr> From<BufferPtrMut<'a, T>> for BufferPtr<'a, T> {
- fn from(b: BufferPtrMut<'a, T>) -> Self {
- BufferPtr {
+impl<'a, T: Copy> From<&'a DeviceBuffer<'a, T>> for Slice<'a, T> {
+ fn from(b: &'a DeviceBuffer<'a, T>) -> Self {
+ Slice {
ptr: b.ptr,
+ len: b.len,
marker: PhantomData,
- elems: b.len(),
}
}
}
-pub struct EventPool<'a>(sys::ze_event_pool_handle_t, PhantomData<&'a ()>);
+
+// Mutability: yes (appends)
+// Lifetime parent: Context
+pub struct EventPool<'a>(NonNull<sys::_ze_event_pool_handle_t>, PhantomData<&'a ()>);
impl<'a> EventPool<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_event_pool_handle_t {
- self.0
+ self.0.as_ptr()
}
pub unsafe fn from_ffi(x: sys::ze_event_pool_handle_t) -> Self {
- Self(x, PhantomData)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
}
- pub fn new(ctx: &mut Context, count: u32, dev: Option<&[&'a Device]>) -> Result<Self> {
+
+ pub fn new(ctx: &'a Context, count: u32, devs: Option<&[Device]>) -> Result<Self> {
let desc = sys::ze_event_pool_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_POOL_DESC,
pNext: ptr::null(),
flags: sys::ze_event_pool_flags_t(0),
count: count,
};
- let mut dev = dev.map(|d| d.iter().map(|d| d.0).collect::<Vec<_>>());
- let dev_len = dev.as_ref().map_or(0, |d| d.len() as u32);
- let dev_ptr = dev.as_mut().map_or(ptr::null_mut(), |d| d.as_mut_ptr());
+ let (dev_len, dev_ptr) = devs.map_or((0, ptr::null_mut()), |devs| {
+ (devs.len(), devs.as_ptr() as *mut _)
+ });
let mut result = ptr::null_mut();
check!(sys::zeEventPoolCreate(
- ctx.0,
+ ctx.as_ffi(),
&desc,
- dev_len,
+ dev_len as u32,
dev_ptr,
&mut result
));
- Ok(Self(result, PhantomData))
+ Ok(unsafe { Self::from_ffi(result) })
}
}
impl<'a> Drop for EventPool<'a> {
- #[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeEventPoolDestroy(self.0) };
+ check_panic! { sys::zeEventPoolDestroy(self.as_ffi()) };
}
}
-pub struct Event<'a>(sys::ze_event_handle_t, PhantomData<&'a ()>);
+pub struct Event<'a>(NonNull<sys::_ze_event_handle_t>, PhantomData<&'a ()>);
impl<'a> Event<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_event_handle_t {
- self.0
+ self.0.as_ptr()
}
-
pub unsafe fn from_ffi(x: sys::ze_event_handle_t) -> Self {
- Self(x, PhantomData)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
}
- pub fn new(pool: &'a EventPool, index: u32) -> Result<Self> {
+ pub fn new(pool: &'a EventPool<'a>, index: u32) -> Result<Self> {
let desc = sys::ze_event_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_EVENT_DESC,
pNext: ptr::null(),
@@ -842,36 +992,37 @@ impl<'a> Event<'a> {
wait: sys::ze_event_scope_flags_t(0),
};
let mut result = ptr::null_mut();
- check!(sys::zeEventCreate(pool.0, &desc, &mut result));
- Ok(Self(result, PhantomData))
+ check!(sys::zeEventCreate(pool.as_ffi(), &desc, &mut result));
+ Ok(unsafe { Self::from_ffi(result) })
}
- unsafe fn raw_slice(e: &mut [Event]) -> (u32, *mut sys::ze_event_handle_t) {
+ unsafe fn raw_slice(e: &[Event]) -> (u32, *mut sys::ze_event_handle_t) {
let ptr = if e.len() == 0 {
- ptr::null_mut()
+ ptr::null()
} else {
- e.as_mut_ptr()
+ e.as_ptr()
};
(e.len() as u32, ptr as *mut sys::ze_event_handle_t)
}
}
impl<'a> Drop for Event<'a> {
- #[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeEventDestroy(self.0) };
+ check_panic! { sys::zeEventDestroy(self.as_ffi()) };
}
}
-pub struct Kernel<'a>(sys::ze_kernel_handle_t, PhantomData<&'a ()>);
+pub struct Kernel<'a>(NonNull<sys::_ze_kernel_handle_t>, PhantomData<&'a ()>);
impl<'a> Kernel<'a> {
pub unsafe fn as_ffi(&self) -> sys::ze_kernel_handle_t {
- self.0
+ self.0.as_ptr()
}
-
pub unsafe fn from_ffi(x: sys::ze_kernel_handle_t) -> Self {
- Self(x, PhantomData)
+ if x == ptr::null_mut() {
+ panic!("FFI handle can't be zero")
+ }
+ Self(NonNull::new_unchecked(x), PhantomData)
}
pub fn new_resident(module: &'a Module, name: &CStr) -> Result<Self> {
@@ -882,26 +1033,23 @@ impl<'a> Kernel<'a> {
pKernelName: name.as_ptr() as *const _,
};
let mut result = ptr::null_mut();
- check!(sys::zeKernelCreate(module.0, &desc, &mut result));
- Ok(Self(result, PhantomData))
+ check!(sys::zeKernelCreate(module.as_ffi(), &desc, &mut result));
+ Ok(unsafe { Self::from_ffi(result) })
}
- pub fn set_indirect_access(
- &mut self,
- flags: sys::ze_kernel_indirect_access_flags_t,
- ) -> Result<()> {
- check!(sys::zeKernelSetIndirectAccess(self.0, flags));
+ pub fn set_indirect_access(&self, flags: sys::ze_kernel_indirect_access_flags_t) -> Result<()> {
+ check!(sys::zeKernelSetIndirectAccess(self.as_ffi(), flags));
Ok(())
}
- pub fn set_arg_buffer<T: 'a, Buff: Into<BufferPtr<'a, T>>>(
+ pub fn set_arg_buffer<T: 'a, Buff: Into<Slice<'a, T>>>(
&self,
index: u32,
buff: Buff,
) -> Result<()> {
- let ptr = unsafe { buff.into().get() };
+ let ptr = buff.into().as_mut_ptr();
check!(sys::zeKernelSetArgumentValue(
- self.0,
+ self.as_ffi(),
index,
mem::size_of::<*const ()>(),
&ptr as *const _ as *const _,
@@ -911,7 +1059,7 @@ impl<'a> Kernel<'a> {
pub fn set_arg_scalar<T: Copy>(&self, index: u32, value: &T) -> Result<()> {
check!(sys::zeKernelSetArgumentValue(
- self.0,
+ self.as_ffi(),
index,
mem::size_of::<T>(),
value as *const T as *const _,
@@ -920,18 +1068,26 @@ impl<'a> Kernel<'a> {
}
pub unsafe fn set_arg_raw(&self, index: u32, size: usize, value: *const c_void) -> Result<()> {
- check!(sys::zeKernelSetArgumentValue(self.0, index, size, value));
+ check!(sys::zeKernelSetArgumentValue(
+ self.as_ffi(),
+ index,
+ size,
+ value
+ ));
Ok(())
}
pub fn set_group_size(&self, x: u32, y: u32, z: u32) -> Result<()> {
- check!(sys::zeKernelSetGroupSize(self.0, x, y, z));
+ check!(sys::zeKernelSetGroupSize(self.as_ffi(), x, y, z));
Ok(())
}
pub fn get_properties(&self) -> Result<Box<sys::ze_kernel_properties_t>> {
let mut props = Box::new(unsafe { mem::zeroed::<sys::ze_kernel_properties_t>() });
- check!(sys::zeKernelGetProperties(self.0, props.as_mut() as *mut _));
+ check!(sys::zeKernelGetProperties(
+ self.as_ffi(),
+ props.as_mut() as *mut _
+ ));
Ok(props)
}
}
@@ -939,7 +1095,7 @@ impl<'a> Kernel<'a> {
impl<'a> Drop for Kernel<'a> {
#[allow(unused_must_use)]
fn drop(&mut self) {
- check_panic! { sys::zeKernelDestroy(self.0) };
+ check_panic! { sys::zeKernelDestroy(self.as_ffi()) };
}
}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 14d3284..94114db 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -201,8 +201,8 @@ impl<T: Debug> error::Error for DisplayError<T> {}
fn test_ptx_assert<
'a,
- Input: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
- Output: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq,
+ Input: From<u8> + Debug + Copy + PartialEq,
+ Output: From<u8> + Debug + Copy + PartialEq,
>(
name: &str,
ptx_text: &'a str,
@@ -220,10 +220,7 @@ fn test_ptx_assert<
Ok(())
}
-fn run_spirv<
- Input: From<u8> + ze::SafeRepr + Copy + Debug,
- Output: From<u8> + ze::SafeRepr + Copy + Debug,
->(
+fn run_spirv<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug>(
name: &CStr,
module: translate::Module,
input: &[Input],
@@ -242,25 +239,25 @@ fn run_spirv<
.get(name.to_str().unwrap())
.map(|info| info.uses_shared_mem)
.unwrap_or(false);
- let mut result = vec![0u8.into(); output.len()];
+ let result = vec![0u8.into(); output.len()];
{
let mut drivers = ze::Driver::get()?;
let drv = drivers.drain(0..1).next().unwrap();
- let mut ctx = ze::Context::new(&drv)?;
let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
- let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
+ let ctx = ze::Context::new(drv, None)?;
+ let queue = ze::CommandQueue::new(&ctx, dev)?;
let (module, maybe_log) = match module.should_link_ptx_impl {
Some(ptx_impl) => ze::Module::build_link_spirv(
- &mut ctx,
- &dev,
+ &ctx,
+ dev,
&[ptx_impl, byte_il],
Some(module.build_options.as_c_str()),
),
None => {
let (module, log) = ze::Module::build_spirv_logged(
- &mut ctx,
- &dev,
+ &ctx,
+ dev,
byte_il,
Some(module.build_options.as_c_str()),
);
@@ -271,38 +268,38 @@ fn run_spirv<
Ok(m) => m,
Err(err) => {
let raw_err_string = maybe_log
- .map(|log| log.get_cstring())
+ .map(|log| log.to_cstring())
.transpose()?
.unwrap_or(CString::default());
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
};
- let mut kernel = ze::Kernel::new_resident(&module, name)?;
+ let kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?;
- let mut inp_b = ze::DeviceBuffer::<Input>::new(&mut ctx, &dev, cmp::max(input.len(), 1))?;
- let mut out_b = ze::DeviceBuffer::<Output>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
- let inp_b_ptr_mut: ze::BufferPtrMut<Input> = (&mut inp_b).into();
- let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
+ let inp_b = ze::DeviceBuffer::<Input>::new(&ctx, dev, cmp::max(input.len(), 1))?;
+ let out_b = ze::DeviceBuffer::<Output>::new(&ctx, dev, cmp::max(output.len(), 1))?;
+ let event_pool = ze::EventPool::new(&ctx, 3, Some(&[dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;
let ev1 = ze::Event::new(&event_pool, 1)?;
- let mut ev2 = ze::Event::new(&event_pool, 2)?;
- let mut cmd_list = ze::CommandList::new(&mut ctx, &dev)?;
- let out_b_ptr_mut: ze::BufferPtrMut<Output> = (&mut out_b).into();
- let mut init_evs = [ev0, ev1];
- cmd_list.append_memory_copy(inp_b_ptr_mut, input, Some(&mut init_evs[0]), &mut [])?;
- cmd_list.append_memory_fill(out_b_ptr_mut, 0, Some(&mut init_evs[1]), &mut [])?;
- kernel.set_group_size(1, 1, 1)?;
- kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
- kernel.set_arg_buffer(1, out_b_ptr_mut)?;
- if use_shared_mem {
- unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
+ let ev2 = ze::Event::new(&event_pool, 2)?;
+ {
+ let cmd_list = ze::CommandList::new(&ctx, dev)?;
+ let init_evs = [ev0, ev1];
+ cmd_list.append_memory_copy(&inp_b, input, Some(&init_evs[0]), &[])?;
+ cmd_list.append_memory_fill(&out_b, 0, Some(&init_evs[1]), &[])?;
+ kernel.set_group_size(1, 1, 1)?;
+ kernel.set_arg_buffer(0, &inp_b)?;
+ kernel.set_arg_buffer(1, &out_b)?;
+ if use_shared_mem {
+ unsafe { kernel.set_arg_raw(2, 128, ptr::null())? };
+ }
+ cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &init_evs)?;
+ cmd_list.append_memory_copy(&*result, &out_b, None, &[ev2])?;
+ queue.execute_and_synchronize(cmd_list)?;
}
- cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&mut ev2), &mut init_evs)?;
- cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, &mut [ev2])?;
- queue.execute(cmd_list)?;
}
Ok(result)
}
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index 2d72460..5ef427e 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -98,8 +98,8 @@ pub struct ContextData {
impl ContextData {
pub fn new(
- l0_ctx: &mut l0::Context,
- l0_dev: &l0::Device,
+ l0_ctx: &'static l0::Context,
+ l0_dev: l0::Device,
flags: c_uint,
is_primary: bool,
dev: *mut device::Device,
@@ -137,7 +137,7 @@ pub fn create_v2(
let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&mut dev.l0_context,
- &dev.base,
+ dev.base,
flags,
false,
dev_ptr as *mut _,
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 29cac2d..63bf39f 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -18,7 +18,7 @@ pub struct Index(pub c_int);
pub struct Device {
pub index: Index,
pub base: l0::Device,
- pub default_queue: l0::CommandQueue,
+ pub default_queue: l0::CommandQueue<'static>,
pub l0_context: l0::Context,
pub primary_context: context::Context,
properties: Option<Box<l0::sys::ze_device_properties_t>>,
@@ -31,12 +31,13 @@ unsafe impl Send for Device {}
impl Device {
// Unsafe because it does not fully initalize primary_context
+ // and we transmute lifetimes left and right
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
- let mut ctx = l0::Context::new(drv)?;
- let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
+ let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
+ let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new(
- &mut ctx,
- &l0_dev,
+ mem::transmute(&ctx),
+ l0_dev,
0,
true,
ptr::null_mut(),
@@ -58,20 +59,18 @@ impl Device {
if let Some(ref prop) = self.properties {
return Ok(prop);
}
- match self.base.get_properties() {
- Ok(prop) => Ok(self.properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_properties(&mut props)?;
+ Ok(self.properties.get_or_insert(Box::new(props)))
}
fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> {
if let Some(ref prop) = self.image_properties {
return Ok(prop);
}
- match self.base.get_image_properties() {
- Ok(prop) => Ok(self.image_properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_image_properties(&mut props)?;
+ Ok(self.image_properties.get_or_insert(Box::new(props)))
}
fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> {
@@ -88,10 +87,9 @@ impl Device {
if let Some(ref prop) = self.compute_properties {
return Ok(prop);
}
- match self.base.get_compute_properties() {
- Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_compute_properties(&mut props)?;
+ Ok(self.compute_properties.get_or_insert(Box::new(props)))
}
pub fn late_init(&mut self) {
@@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
}
// TODO: add support if Level 0 exposes it
-pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> {
+pub fn get_luid(
+ luid: *mut c_char,
+ dev_node_mask: *mut c_uint,
+ _dev_idx: Index,
+) -> Result<(), CUresult> {
unsafe { ptr::write_bytes(luid, 0u8, 8) };
unsafe { *dev_node_mask = 0 };
Ok(())
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 11f15e6..e236160 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -144,14 +144,14 @@ pub fn launch_kernel(
func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
func.legacy_args.reset();
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel(
&mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z],
None,
&mut [],
)?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok(())
})?
}
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index f33a08c..5db6472 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??;
unsafe { *dptr = ptr };
Ok(())
@@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
- unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
- stream.queue.execute(cmd_list)?;
+ let cmd_list = stream.command_list()?;
+ unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
@@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
})
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index 98580f8..6268904 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -41,7 +41,7 @@ pub struct SpirvModule {
}
pub struct CompiledModule {
- pub base: l0::Module,
+ pub base: l0::Module<'static>,
pub kernels: HashMap<CString, Box<Function>>,
}
@@ -78,7 +78,11 @@ impl SpirvModule {
})
}
- pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
+ pub fn compile<'a>(
+ &self,
+ ctx: &'a l0::Context,
+ dev: l0::Device,
+ ) -> Result<l0::Module<'a>, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
@@ -86,13 +90,11 @@ impl SpirvModule {
)
};
let l0_module = match self.should_link_ptx_impl {
- None => {
- l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str()))
- }
+ None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
Some(ptx_impl) => {
l0::Module::build_link_spirv(
ctx,
- &dev,
+ dev,
&[ptx_impl, byte_il],
Some(self.build_options.as_c_str()),
)
@@ -119,7 +121,7 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
- base: module.spirv.compile(&mut device.l0_context, &device.base)?,
+ base: module.spirv.compile(&mut device.l0_context, device.base)?,
kernels: HashMap::new(),
};
entry.insert(new_module)
@@ -135,7 +137,7 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
- let mut kernel =
+ let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
@@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
- let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
+ let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,
diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs
index e212dfc..0fafe92 100644
--- a/zluda/src/impl/stream.rs
+++ b/zluda/src/impl/stream.rs
@@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData {
pub struct StreamData {
pub context: *mut ContextData,
- pub queue: l0::CommandQueue,
+ pub queue: l0::CommandQueue<'static>,
}
impl StreamData {
- pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
+ pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?,
@@ -45,7 +45,7 @@ impl StreamData {
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
- let l0_dev = &unsafe { &*ctx.device }.base;
+ let l0_dev = unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
@@ -55,7 +55,7 @@ impl StreamData {
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device };
- l0::CommandList::new(&mut dev.l0_context, &dev.base)
+ l0::CommandList::new(&mut dev.l0_context, dev.base)
}
}
diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs
index 75f3ca2..1068b00 100644
--- a/zluda_ml/src/impl.rs
+++ b/zluda_ml/src/impl.rs
@@ -127,7 +127,8 @@ pub(crate) fn system_get_driver_version(
len: 0,
};
for d in drivers {
- let props = d.get_properties()?;
+ let mut props = Default::default();
+ d.get_properties(&mut props)?;
let driver_version = props.driverVersion;
write!(&mut output_write, "{}", driver_version)
.map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;