diff options
Diffstat (limited to 'zluda/src/impl/module.rs')
-rw-r--r-- | zluda/src/impl/module.rs | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs new file mode 100644 index 0000000..cba030e --- /dev/null +++ b/zluda/src/impl/module.rs @@ -0,0 +1,188 @@ +use std::{ + collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, + os::raw::c_char, ptr, slice, +}; + +use super::{ + device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie, + LiveCheck, +}; +use ptx; + +pub type Module = LiveCheck<ModuleData>; + +impl HasLivenessCookie for ModuleData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0xf1313bd46505f98a; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0xbdbe3f15; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + Ok(()) + } +} + +pub struct ModuleData { + pub spirv: SpirvModule, + // This should be a Vec<>, but I'm feeling lazy + pub device_binaries: HashMap<device::Index, CompiledModule>, +} + +pub struct SpirvModule { + pub binaries: Vec<u32>, + pub kernel_info: HashMap<String, ptx::KernelInfo>, + pub should_link_ptx_impl: Option<&'static [u8]>, + pub build_options: CString, +} + +pub struct CompiledModule { + pub base: l0::Module, + pub kernels: HashMap<CString, Box<Function>>, +} + +impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult { + fn from(_: ptx::ParseError<L, T, E>) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX + } +} + +impl From<ptx::TranslateError> for CUresult { + fn from(_: ptx::TranslateError) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX + } +} + +impl SpirvModule { + pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> { + let u8_text = unsafe { CStr::from_ptr(text) }; + let ptx_text = u8_text + .to_str() + .map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?; + Self::new(ptx_text) + } + + pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> { + let mut errors = Vec::new(); + let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; + let spirv_module = ptx::to_spirv_module(ast)?; + Ok(SpirvModule { + binaries: spirv_module.assemble(), + kernel_info: spirv_module.kernel_info, + should_link_ptx_impl: spirv_module.should_link_ptx_impl, + build_options: spirv_module.build_options, + }) + } + + pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> { + let byte_il = unsafe { + slice::from_raw_parts( + self.binaries.as_ptr() as *const u8, + self.binaries.len() * mem::size_of::<u32>(), + ) + }; + let l0_module = match self.should_link_ptx_impl { + None => { + l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())).0 + } + Some(ptx_impl) => { + l0::Module::build_link_spirv( + ctx, + &dev, + &[ptx_impl, byte_il], + Some(self.build_options.as_c_str()), + ) + .0 + } + }; + Ok(l0_module?) + } +} + +pub fn get_function( + hfunc: *mut *mut Function, + hmod: *mut Module, + name: *const c_char, +) -> Result<(), CUresult> { + if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let name = unsafe { CStr::from_ptr(name) }.to_owned(); + let function: *mut Function = GlobalState::lock_current_context(|ctx| { + let module = unsafe { &mut *hmod }.as_result_mut()?; + let device = unsafe { &mut *ctx.device }; + let compiled_module = match module.device_binaries.entry(device.index) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let new_module = CompiledModule { + base: module.spirv.compile(&mut device.l0_context, &device.base)?, + kernels: HashMap::new(), + }; + entry.insert(new_module) + } + }; + let kernel = match compiled_module.kernels.entry(name) { + hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(), + hash_map::Entry::Vacant(entry) => { + let kernel_info = module + .spirv + .kernel_info + .get(unsafe { + std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) + }) + .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; + let mut kernel = + l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; + kernel.set_indirect_access( + l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE + | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST + | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED + )?; + entry.insert(Box::new(Function::new(FunctionData { + base: kernel, + arg_size: kernel_info.arguments_sizes.clone(), + use_shared_mem: kernel_info.uses_shared_mem, + properties: None, + }))) + } + }; + Ok::<_, CUresult>(kernel as *mut _) + })??; + unsafe { *hfunc = function }; + Ok(()) +} + +pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> { + let spirv_data = SpirvModule::new_raw(image as *const _)?; + load_data_impl(pmod, spirv_data) +} + +pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { + let module = GlobalState::lock_current_context(|ctx| { + let device = unsafe { &mut *ctx.device }; + let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; + let mut device_binaries = HashMap::new(); + let compiled_module = CompiledModule { + base: l0_module, + kernels: HashMap::new(), + }; + device_binaries.insert(device.index, compiled_module); + let module_data = ModuleData { + spirv: spirv_data, + device_binaries, + }; + Ok::<_, CUresult>(module_data) + })??; + let module_ptr = Box::into_raw(Box::new(Module::new(module))); + unsafe { *pmod = module_ptr }; + Ok(()) +} + +pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> { + if module == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Module::destroy_impl(module))? +} |