summaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/module.rs
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda/src/impl/module.rs')
-rw-r--r--notcuda/src/impl/module.rs206
1 files changed, 129 insertions, 77 deletions
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 35436c3..4422107 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -1,79 +1,90 @@
use std::{
- collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
+ collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
+ os::raw::c_char, ptr, slice,
};
-use super::{function::Function, transmute_lifetime, CUresult};
+use super::{
+ device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie,
+ LiveCheck,
+};
use ptx;
-pub type Module = Mutex<ModuleData>;
+pub type Module = LiveCheck<ModuleData>;
+
+impl HasLivenessCookie for ModuleData {
+ #[cfg(target_pointer_width = "64")]
+ const COOKIE: usize = 0xf1313bd46505f98a;
+
+ #[cfg(target_pointer_width = "32")]
+ const COOKIE: usize = 0xbdbe3f15;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ Ok(())
+ }
+}
pub struct ModuleData {
- base: l0::Module,
- arg_lens: HashMap<CString, Vec<usize>>,
+ pub spirv: SpirvModule,
+ // This should be a Vec<>, but I'm feeling lazy
+ pub device_binaries: HashMap<device::Index, CompiledModule>,
}
-pub enum ModuleCompileError<'a> {
- Parse(
- Vec<ptx::ast::PtxError>,
- Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
- ),
- Compile(ptx::TranslateError),
- L0(l0::sys::ze_result_t),
- CUDA(CUresult),
+pub struct SpirvModule {
+ pub binaries: Vec<u32>,
+ pub kernel_info: HashMap<String, ptx::KernelInfo>,
+ pub should_link_ptx_impl: Option<&'static [u8]>,
+ pub build_options: CString,
}
-impl<'a> ModuleCompileError<'a> {
- pub fn get_build_log(&self) {}
+pub struct CompiledModule {
+ pub base: l0::Module,
+ pub kernels: HashMap<CString, Box<Function>>,
}
-impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
- fn from(err: ptx::TranslateError) -> Self {
- ModuleCompileError::Compile(err)
+impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult {
+ fn from(_: ptx::ParseError<L, T, E>) -> Self {
+ CUresult::CUDA_ERROR_INVALID_PTX
}
}
-impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
- fn from(err: l0::sys::ze_result_t) -> Self {
- ModuleCompileError::L0(err)
+impl From<ptx::TranslateError> for CUresult {
+ fn from(_: ptx::TranslateError) -> Self {
+ CUresult::CUDA_ERROR_INVALID_PTX
}
}
-impl<'a> From<CUresult> for ModuleCompileError<'a> {
- fn from(err: CUresult) -> Self {
- ModuleCompileError::CUDA(err)
+impl SpirvModule {
+ pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> {
+ let u8_text = unsafe { CStr::from_ptr(text) };
+ let ptx_text = u8_text
+ .to_str()
+ .map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?;
+ Self::new(ptx_text)
}
-}
-impl ModuleData {
- pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
+ pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> {
let mut errors = Vec::new();
- let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
- let ast = match ast {
- Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
- Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
- Ok(ast) => ast,
- };
- let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
+ let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
+ let spirv_module = ptx::to_spirv_module(ast)?;
+ Ok(SpirvModule {
+ binaries: spirv_module.assemble(),
+ kernel_info: spirv_module.kernel_info,
+ should_link_ptx_impl: spirv_module.should_link_ptx_impl,
+ build_options: spirv_module.build_options,
+ })
+ }
+
+ pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
let byte_il = unsafe {
- slice::from_raw_parts::<u8>(
- spirv.as_ptr() as *const _,
- spirv.len() * mem::size_of::<u32>(),
+ slice::from_raw_parts(
+ self.binaries.as_ptr() as *const u8,
+ self.binaries.len() * mem::size_of::<u32>(),
)
};
- let module = super::device::with_current_exclusive(|dev| {
- l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
- });
- match module {
- Ok((Ok(module), _)) => Ok(Mutex::new(Self {
- base: module,
- arg_lens: all_arg_lens
- .into_iter()
- .map(|(k, v)| (CString::new(k).unwrap(), v))
- .collect(),
- })),
- Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
- Err(err) => Err(ModuleCompileError::from(err)),
- }
+ let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?;
+ Ok(l0_module)
}
}
@@ -85,34 +96,75 @@ pub fn get_function(
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let name = unsafe { CStr::from_ptr(name) };
- let (mut kernel, args_len) = unsafe { &*hmod }
- .try_lock()
- .map(|module| {
- Result::<_, CUresult>::Ok((
- l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
- module
- .arg_lens
- .get(name)
- .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
- .clone(),
- ))
- })
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
- kernel.set_indirect_access(
- l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
- | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
- | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
- )?;
- unsafe {
- *hfunc = Box::into_raw(Box::new(Function {
- base: kernel,
- arg_size: args_len,
- }))
- };
+ let name = unsafe { CStr::from_ptr(name) }.to_owned();
+ let function: *mut Function = GlobalState::lock_current_context(|ctx| {
+ let module = unsafe { &mut *hmod }.as_result_mut()?;
+ let device = unsafe { &mut *ctx.device };
+ let compiled_module = match module.device_binaries.entry(device.index) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let new_module = CompiledModule {
+ base: module.spirv.compile(&mut device.l0_context, &device.base)?,
+ kernels: HashMap::new(),
+ };
+ entry.insert(new_module)
+ }
+ };
+ //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
+ let kernel = match compiled_module.kernels.entry(name) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let kernel_info = module
+ .spirv
+ .kernel_info
+ .get(unsafe {
+ std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
+ })
+ .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
+ let kernel =
+ l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
+ entry.insert(Box::new(Function::new(FunctionData {
+ base: kernel,
+ arg_size: kernel_info.arguments_sizes.clone(),
+ use_shared_mem: kernel_info.uses_shared_mem,
+ })))
+ }
+ };
+ Ok::<_, CUresult>(kernel as *mut _)
+ })??;
+ unsafe { *hfunc = function };
Ok(())
}
-pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
+pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> {
+ let spirv_data = SpirvModule::new_raw(image as *const _)?;
+ load_data_impl(pmod, spirv_data)
+}
+
+pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
+ let module = GlobalState::lock_current_context(|ctx| {
+ let device = unsafe { &mut *ctx.device };
+ let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
+ let mut device_binaries = HashMap::new();
+ let compiled_module = CompiledModule {
+ base: l0_module,
+ kernels: HashMap::new(),
+ };
+ device_binaries.insert(device.index, compiled_module);
+ let module_data = ModuleData {
+ spirv: spirv_data,
+ device_binaries,
+ };
+ Ok::<_, CUresult>(module_data)
+ })??;
+ let module_ptr = Box::into_raw(Box::new(Module::new(module)));
+ unsafe { *pmod = module_ptr };
Ok(())
}
+
+pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> {
+ if module == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock(|_| Module::destroy_impl(module))?
+}