aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-05-27 02:05:17 +0200
committerAndrzej Janik <[email protected]>2021-05-27 02:05:17 +0200
commite40785aa7491de16c65de7aa599105102ffa7355 (patch)
tree87b4b16dbf6318aae8456a04ab6af574d2238ddb /zluda
parent58a7fe53c6feaf96156c455b7c3b1def9d7e6d56 (diff)
downloadZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.tar.gz
ZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.zip
Refactor L0 bindings
Diffstat (limited to 'zluda')
-rw-r--r--zluda/src/impl/context.rs6
-rw-r--r--zluda/src/impl/device.rs38
-rw-r--r--zluda/src/impl/function.rs4
-rw-r--r--zluda/src/impl/memory.rs18
-rw-r--r--zluda/src/impl/module.rs20
-rw-r--r--zluda/src/impl/stream.rs8
6 files changed, 49 insertions, 45 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index 2d72460..5ef427e 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -98,8 +98,8 @@ pub struct ContextData {
impl ContextData {
pub fn new(
- l0_ctx: &mut l0::Context,
- l0_dev: &l0::Device,
+ l0_ctx: &'static l0::Context,
+ l0_dev: l0::Device,
flags: c_uint,
is_primary: bool,
dev: *mut device::Device,
@@ -137,7 +137,7 @@ pub fn create_v2(
let dev_ptr = dev as *mut _;
let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
&mut dev.l0_context,
- &dev.base,
+ dev.base,
flags,
false,
dev_ptr as *mut _,
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 29cac2d..63bf39f 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -18,7 +18,7 @@ pub struct Index(pub c_int);
pub struct Device {
pub index: Index,
pub base: l0::Device,
- pub default_queue: l0::CommandQueue,
+ pub default_queue: l0::CommandQueue<'static>,
pub l0_context: l0::Context,
pub primary_context: context::Context,
properties: Option<Box<l0::sys::ze_device_properties_t>>,
@@ -31,12 +31,13 @@ unsafe impl Send for Device {}
impl Device {
// Unsafe because it does not fully initalize primary_context
+ // and we transmute lifetimes left and right
unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
- let mut ctx = l0::Context::new(drv)?;
- let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
+ let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?;
+ let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new(
- &mut ctx,
- &l0_dev,
+ mem::transmute(&ctx),
+ l0_dev,
0,
true,
ptr::null_mut(),
@@ -58,20 +59,18 @@ impl Device {
if let Some(ref prop) = self.properties {
return Ok(prop);
}
- match self.base.get_properties() {
- Ok(prop) => Ok(self.properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_properties(&mut props)?;
+ Ok(self.properties.get_or_insert(Box::new(props)))
}
fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> {
if let Some(ref prop) = self.image_properties {
return Ok(prop);
}
- match self.base.get_image_properties() {
- Ok(prop) => Ok(self.image_properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_image_properties(&mut props)?;
+ Ok(self.image_properties.get_or_insert(Box::new(props)))
}
fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> {
@@ -88,10 +87,9 @@ impl Device {
if let Some(ref prop) = self.compute_properties {
return Ok(prop);
}
- match self.base.get_compute_properties() {
- Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)),
- Err(e) => Err(e),
- }
+ let mut props = Default::default();
+ self.base.get_compute_properties(&mut props)?;
+ Ok(self.compute_properties.get_or_insert(Box::new(props)))
}
pub fn late_init(&mut self) {
@@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
}
// TODO: add support if Level 0 exposes it
-pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> {
+pub fn get_luid(
+ luid: *mut c_char,
+ dev_node_mask: *mut c_uint,
+ _dev_idx: Index,
+) -> Result<(), CUresult> {
unsafe { ptr::write_bytes(luid, 0u8, 8) };
unsafe { *dev_node_mask = 0 };
Ok(())
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 11f15e6..e236160 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -144,14 +144,14 @@ pub fn launch_kernel(
func.base
.set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
func.legacy_args.reset();
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel(
&mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z],
None,
&mut [],
)?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok(())
})?
}
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index f33a08c..5db6472 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -4,7 +4,7 @@ use std::{ffi::c_void, mem};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??;
unsafe { *dptr = ptr };
Ok(())
@@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult>
pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
- unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
- stream.queue.execute(cmd_list)?;
+ let cmd_list = stream.command_list()?;
+ unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
@@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<
pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
})
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs
index 98580f8..6268904 100644
--- a/zluda/src/impl/module.rs
+++ b/zluda/src/impl/module.rs
@@ -41,7 +41,7 @@ pub struct SpirvModule {
}
pub struct CompiledModule {
- pub base: l0::Module,
+ pub base: l0::Module<'static>,
pub kernels: HashMap<CString, Box<Function>>,
}
@@ -78,7 +78,11 @@ impl SpirvModule {
})
}
- pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
+ pub fn compile<'a>(
+ &self,
+ ctx: &'a l0::Context,
+ dev: l0::Device,
+ ) -> Result<l0::Module<'a>, CUresult> {
let byte_il = unsafe {
slice::from_raw_parts(
self.binaries.as_ptr() as *const u8,
@@ -86,13 +90,11 @@ impl SpirvModule {
)
};
let l0_module = match self.should_link_ptx_impl {
- None => {
- l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str()))
- }
+ None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())),
Some(ptx_impl) => {
l0::Module::build_link_spirv(
ctx,
- &dev,
+ dev,
&[ptx_impl, byte_il],
Some(self.build_options.as_c_str()),
)
@@ -119,7 +121,7 @@ pub fn get_function(
hash_map::Entry::Occupied(entry) => entry.into_mut(),
hash_map::Entry::Vacant(entry) => {
let new_module = CompiledModule {
- base: module.spirv.compile(&mut device.l0_context, &device.base)?,
+ base: module.spirv.compile(&mut device.l0_context, device.base)?,
kernels: HashMap::new(),
};
entry.insert(new_module)
@@ -135,7 +137,7 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
- let mut kernel =
+ let kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
@@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<
pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
let module = GlobalState::lock_current_context(|ctx| {
let device = unsafe { &mut *ctx.device };
- let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
+ let l0_module = spirv_data.compile(&device.l0_context, device.base)?;
let mut device_binaries = HashMap::new();
let compiled_module = CompiledModule {
base: l0_module,
diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs
index e212dfc..0fafe92 100644
--- a/zluda/src/impl/stream.rs
+++ b/zluda/src/impl/stream.rs
@@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData {
pub struct StreamData {
pub context: *mut ContextData,
- pub queue: l0::CommandQueue,
+ pub queue: l0::CommandQueue<'static>,
}
impl StreamData {
- pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
+ pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> {
Ok(StreamData {
context: ptr::null_mut(),
queue: l0::CommandQueue::new(ctx, dev)?,
@@ -45,7 +45,7 @@ impl StreamData {
}
pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
- let l0_dev = &unsafe { &*ctx.device }.base;
+ let l0_dev = unsafe { &*ctx.device }.base;
Ok(StreamData {
context: ctx as *mut _,
queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
@@ -55,7 +55,7 @@ impl StreamData {
pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
let ctx = unsafe { &mut *self.context };
let dev = unsafe { &mut *ctx.device };
- l0::CommandList::new(&mut dev.l0_context, &dev.base)
+ l0::CommandList::new(&mut dev.l0_context, dev.base)
}
}