1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
|
use std::{
collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
};
use super::{function::Function, transmute_lifetime, CUresult};
use ptx;
pub type Module = Mutex<ModuleData>;
pub struct ModuleData {
base: l0::Module,
arg_lens: HashMap<CString, Vec<usize>>,
}
pub enum ModuleCompileError<'a> {
Parse(
Vec<ptx::ast::PtxError>,
Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
),
Compile(ptx::TranslateError),
L0(l0::sys::ze_result_t),
CUDA(CUresult),
}
impl<'a> ModuleCompileError<'a> {
pub fn get_build_log(&self) {}
}
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
fn from(err: ptx::TranslateError) -> Self {
ModuleCompileError::Compile(err)
}
}
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
fn from(err: l0::sys::ze_result_t) -> Self {
ModuleCompileError::L0(err)
}
}
impl<'a> From<CUresult> for ModuleCompileError<'a> {
fn from(err: CUresult) -> Self {
ModuleCompileError::CUDA(err)
}
}
impl ModuleData {
pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
let ast = match ast {
Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
Ok(ast) => ast,
};
let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
let byte_il = unsafe {
slice::from_raw_parts::<u8>(
spirv.as_ptr() as *const _,
spirv.len() * mem::size_of::<u32>(),
)
};
let module = super::device::with_current_exclusive(|dev| {
l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
});
match module {
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
base: module,
arg_lens: all_arg_lens
.into_iter()
.map(|(k, v)| (CString::new(k).unwrap(), v))
.collect(),
})),
Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
Err(err) => Err(ModuleCompileError::from(err)),
}
}
}
pub fn get_function(
hfunc: *mut *mut Function,
hmod: *mut Module,
name: *const c_char,
) -> Result<(), CUresult> {
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
let name = unsafe { CStr::from_ptr(name) };
let (mut kernel, args_len) = unsafe { &*hmod }
.try_lock()
.map(|module| {
Result::<_, CUresult>::Ok((
l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
module
.arg_lens
.get(name)
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
.clone(),
))
})
.map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
kernel.set_indirect_access(
l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
| l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
)?;
unsafe {
*hfunc = Box::into_raw(Box::new(Function {
base: kernel,
arg_size: args_len,
}))
};
Ok(())
}
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
Ok(())
}
|