diff options
author | Andrzej Janik <[email protected]> | 2021-05-27 02:05:17 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2021-05-27 02:05:17 +0200 |
commit | e40785aa7491de16c65de7aa599105102ffa7355 (patch) | |
tree | 87b4b16dbf6318aae8456a04ab6af574d2238ddb /zluda | |
parent | 58a7fe53c6feaf96156c455b7c3b1def9d7e6d56 (diff) | |
download | ZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.tar.gz ZLUDA-e40785aa7491de16c65de7aa599105102ffa7355.zip |
Refactor L0 bindings
Diffstat (limited to 'zluda')
-rw-r--r-- | zluda/src/impl/context.rs | 6 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 38 | ||||
-rw-r--r-- | zluda/src/impl/function.rs | 4 | ||||
-rw-r--r-- | zluda/src/impl/memory.rs | 18 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 20 | ||||
-rw-r--r-- | zluda/src/impl/stream.rs | 8 |
6 files changed, 49 insertions, 45 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index 2d72460..5ef427e 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -98,8 +98,8 @@ pub struct ContextData { impl ContextData { pub fn new( - l0_ctx: &mut l0::Context, - l0_dev: &l0::Device, + l0_ctx: &'static l0::Context, + l0_dev: l0::Device, flags: c_uint, is_primary: bool, dev: *mut device::Device, @@ -137,7 +137,7 @@ pub fn create_v2( let dev_ptr = dev as *mut _; let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( &mut dev.l0_context, - &dev.base, + dev.base, flags, false, dev_ptr as *mut _, diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 29cac2d..63bf39f 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -18,7 +18,7 @@ pub struct Index(pub c_int); pub struct Device { pub index: Index, pub base: l0::Device, - pub default_queue: l0::CommandQueue, + pub default_queue: l0::CommandQueue<'static>, pub l0_context: l0::Context, pub primary_context: context::Context, properties: Option<Box<l0::sys::ze_device_properties_t>>, @@ -31,12 +31,13 @@ unsafe impl Send for Device {} impl Device { // Unsafe because it does not fully initalize primary_context + // and we transmute lifetimes left and right unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> { - let mut ctx = l0::Context::new(drv)?; - let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?; + let ctx = l0::Context::new(*drv, Some(&[l0_dev]))?; + let queue = l0::CommandQueue::new(mem::transmute(&ctx), l0_dev)?; let primary_context = context::Context::new(context::ContextData::new( - &mut ctx, - &l0_dev, + mem::transmute(&ctx), + l0_dev, 0, true, ptr::null_mut(), @@ -58,20 +59,18 @@ impl Device { if let Some(ref prop) = self.properties { return Ok(prop); } - match self.base.get_properties() { - Ok(prop) => Ok(self.properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_properties(&mut props)?; + Ok(self.properties.get_or_insert(Box::new(props))) } fn get_image_properties(&mut self) -> l0::Result<&l0::sys::ze_device_image_properties_t> { if let Some(ref prop) = self.image_properties { return Ok(prop); } - match self.base.get_image_properties() { - Ok(prop) => Ok(self.image_properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_image_properties(&mut props)?; + Ok(self.image_properties.get_or_insert(Box::new(props))) } fn get_memory_properties(&mut self) -> l0::Result<&[l0::sys::ze_device_memory_properties_t]> { @@ -88,10 +87,9 @@ impl Device { if let Some(ref prop) = self.compute_properties { return Ok(prop); } - match self.base.get_compute_properties() { - Ok(prop) => Ok(self.compute_properties.get_or_insert(prop)), - Err(e) => Err(e), - } + let mut props = Default::default(); + self.base.get_compute_properties(&mut props)?; + Ok(self.compute_properties.get_or_insert(Box::new(props))) } pub fn late_init(&mut self) { @@ -351,7 +349,11 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> { } // TODO: add support if Level 0 exposes it -pub fn get_luid(luid: *mut c_char, dev_node_mask: *mut c_uint, _dev_idx: Index) -> Result<(), CUresult> { +pub fn get_luid( + luid: *mut c_char, + dev_node_mask: *mut c_uint, + _dev_idx: Index, +) -> Result<(), CUresult> { unsafe { ptr::write_bytes(luid, 0u8, 8) }; unsafe { *dev_node_mask = 0 }; Ok(()) diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 11f15e6..e236160 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -144,14 +144,14 @@ pub fn launch_kernel( func.base .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; func.legacy_args.reset(); - let mut cmd_list = stream.command_list()?; + let cmd_list = stream.command_list()?; cmd_list.append_launch_kernel( &mut func.base, &[grid_dim_x, grid_dim_y, grid_dim_z], None, &mut [], )?; - stream.queue.execute(cmd_list)?; + stream.queue.execute_and_synchronize(cmd_list)?; Ok(()) })? } diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index f33a08c..5db6472 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -4,7 +4,7 @@ use std::{ffi::c_void, mem}; pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
let ptr = GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_alloc_device(bytesize, 0, dev.base)?)
})??;
unsafe { *dptr = ptr };
Ok(())
@@ -12,9 +12,9 @@ pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
- unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
- stream.queue.execute(cmd_list)?;
+ let cmd_list = stream.command_list()?;
+ unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut [])? };
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
@@ -22,29 +22,29 @@ pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result< pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
GlobalState::lock_current_context(|ctx| {
let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
+ Ok::<_, CUresult>(dev.l0_context.mem_free(ptr)?)
})
.map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
pub(crate) fn set_d32_v2(dst: *mut c_void, ui: u32, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &ui, mem::size_of::<u32>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
pub(crate) fn set_d8_v2(dst: *mut c_void, uc: u8, n: usize) -> Result<(), CUresult> {
GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
- let mut cmd_list = stream.command_list()?;
+ let cmd_list = stream.command_list()?;
unsafe {
cmd_list.append_memory_fill_unsafe(dst, &uc, mem::size_of::<u8>() * n, None, &mut [])
}?;
- stream.queue.execute(cmd_list)?;
+ stream.queue.execute_and_synchronize(cmd_list)?;
Ok::<_, CUresult>(())
})?
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 98580f8..6268904 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -41,7 +41,7 @@ pub struct SpirvModule { } pub struct CompiledModule { - pub base: l0::Module, + pub base: l0::Module<'static>, pub kernels: HashMap<CString, Box<Function>>, } @@ -78,7 +78,11 @@ impl SpirvModule { }) } - pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> { + pub fn compile<'a>( + &self, + ctx: &'a l0::Context, + dev: l0::Device, + ) -> Result<l0::Module<'a>, CUresult> { let byte_il = unsafe { slice::from_raw_parts( self.binaries.as_ptr() as *const u8, @@ -86,13 +90,11 @@ impl SpirvModule { ) }; let l0_module = match self.should_link_ptx_impl { - None => { - l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())) - } + None => l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())), Some(ptx_impl) => { l0::Module::build_link_spirv( ctx, - &dev, + dev, &[ptx_impl, byte_il], Some(self.build_options.as_c_str()), ) @@ -119,7 +121,7 @@ pub fn get_function( hash_map::Entry::Occupied(entry) => entry.into_mut(), hash_map::Entry::Vacant(entry) => { let new_module = CompiledModule { - base: module.spirv.compile(&mut device.l0_context, &device.base)?, + base: module.spirv.compile(&mut device.l0_context, device.base)?, kernels: HashMap::new(), }; entry.insert(new_module) @@ -135,7 +137,7 @@ pub fn get_function( std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) }) .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; - let mut kernel = + let kernel = l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; kernel.set_indirect_access( l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE @@ -165,7 +167,7 @@ pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result< pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { let module = GlobalState::lock_current_context(|ctx| { let device = unsafe { &mut *ctx.device }; - let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; + let l0_module = spirv_data.compile(&device.l0_context, device.base)?; let mut device_binaries = HashMap::new(); let compiled_module = CompiledModule { base: l0_module, diff --git a/zluda/src/impl/stream.rs b/zluda/src/impl/stream.rs index e212dfc..0fafe92 100644 --- a/zluda/src/impl/stream.rs +++ b/zluda/src/impl/stream.rs @@ -33,11 +33,11 @@ impl HasLivenessCookie for StreamData { pub struct StreamData { pub context: *mut ContextData, - pub queue: l0::CommandQueue, + pub queue: l0::CommandQueue<'static>, } impl StreamData { - pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> { + pub fn new_unitialized(ctx: &'static l0::Context, dev: l0::Device) -> Result<Self, CUresult> { Ok(StreamData { context: ptr::null_mut(), queue: l0::CommandQueue::new(ctx, dev)?, @@ -45,7 +45,7 @@ impl StreamData { } pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; - let l0_dev = &unsafe { &*ctx.device }.base; + let l0_dev = unsafe { &*ctx.device }.base; Ok(StreamData { context: ctx as *mut _, queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, @@ -55,7 +55,7 @@ impl StreamData { pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> { let ctx = unsafe { &mut *self.context }; let dev = unsafe { &mut *ctx.device }; - l0::CommandList::new(&mut dev.l0_context, &dev.base) + l0::CommandList::new(&mut dev.l0_context, dev.base) } } |