summaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/module.rs
blob: 35436c31edfc53b500740ede80796ddc046b34f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
use std::{
    collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
};

use super::{function::Function, transmute_lifetime, CUresult};
use ptx;

/// Module handle type: a [`ModuleData`] wrapped in a [`Mutex`].
///
/// `get_function` receives this as a raw `*mut Module` and accesses the
/// contents through `try_lock`.
pub type Module = Mutex<ModuleData>;

/// State kept for one compiled module.
pub struct ModuleData {
    /// The Level Zero module produced by `l0::Module::build_spirv` in
    /// [`ModuleData::compile_spirv`].
    base: l0::Module,
    /// Per-name list of `usize` values taken from the third element returned by
    /// `ptx::to_spirv`. `get_function` looks a name up here and copies the
    /// list into `Function::arg_size`.
    /// NOTE(review): presumably the byte size of each kernel argument — confirm
    /// against `ptx::to_spirv`.
    arg_lens: HashMap<CString, Vec<usize>>,
}

/// Failure modes of [`ModuleData::compile_spirv`].
///
/// The lifetime `'a` is the lifetime of the PTX source text, because the
/// parser error carries `ptx::Token<'a>` values.
pub enum ModuleCompileError<'a> {
    /// Parsing failed. The vector holds the errors the parser pushed into its
    /// error list; the option is `Some` when the parser itself returned `Err`
    /// and `None` when it returned `Ok` but the error list was non-empty.
    Parse(
        Vec<ptx::ast::PtxError>,
        Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
    ),
    /// `ptx::to_spirv` rejected the parsed module.
    Compile(ptx::TranslateError),
    /// A Level Zero call returned an error result.
    L0(l0::sys::ze_result_t),
    /// A failure already expressed as a CUDA result code.
    CUDA(CUresult),
}

impl<'a> ModuleCompileError<'a> {
    /// Placeholder: the body is empty, so this currently returns `()` and
    /// exposes no build-log data.
    pub fn get_build_log(&self) {}
}

/// Lets `?` turn a PTX translation failure into
/// [`ModuleCompileError::Compile`].
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
    fn from(translate_err: ptx::TranslateError) -> Self {
        Self::Compile(translate_err)
    }
}

/// Wraps a Level Zero result code in [`ModuleCompileError::L0`].
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
    fn from(ze_err: l0::sys::ze_result_t) -> Self {
        Self::L0(ze_err)
    }
}

impl<'a> From<CUresult> for ModuleCompileError<'a> {
    fn from(err: CUresult) -> Self {
        ModuleCompileError::CUDA(err)
    }
}

impl ModuleData {
    /// Compiles PTX source text into a [`Module`].
    ///
    /// The text is parsed, translated to SPIR-V with `ptx::to_spirv`, and the
    /// SPIR-V is handed to `l0::Module::build_spirv` on the current device.
    ///
    /// # Errors
    ///
    /// * [`ModuleCompileError::Parse`] if the parser returns an error or
    ///   records any error in its error list.
    /// * [`ModuleCompileError::Compile`] if `ptx::to_spirv` fails.
    /// * [`ModuleCompileError::L0`] if building the module fails.
    /// * [`ModuleCompileError::CUDA`] if `with_current_exclusive` fails.
    ///
    /// # Panics
    ///
    /// Panics if a name returned by `ptx::to_spirv` cannot be converted to a
    /// `CString` (i.e. it contains an interior NUL byte).
    pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
        let mut parse_errors = Vec::new();
        let parsed = ptx::ModuleParser::new().parse(&mut parse_errors, ptx_text);
        let ast = match parsed {
            Ok(ast) if parse_errors.is_empty() => ast,
            // The parser produced a result but still recorded errors.
            Ok(_) => return Err(ModuleCompileError::Parse(parse_errors, None)),
            Err(fatal) => return Err(ModuleCompileError::Parse(parse_errors, Some(fatal))),
        };

        let (_, spirv_words, arg_lens_by_name) = ptx::to_spirv(ast)?;

        // SAFETY: reinterprets the buffer behind `spirv_words` as bytes; the
        // pointer and the element count come from the same live value, and the
        // byte length is the element count times `size_of::<u32>()`.
        // NOTE(review): relies on `spirv_words` being a contiguous buffer of
        // `u32` — confirm against `ptx::to_spirv`.
        let spirv_bytes = unsafe {
            let byte_len = spirv_words.len() * mem::size_of::<u32>();
            slice::from_raw_parts::<u8>(spirv_words.as_ptr() as *const _, byte_len)
        };

        let build_result = super::device::with_current_exclusive(|dev| {
            l0::Module::build_spirv(&mut dev.l0_context, &dev.base, spirv_bytes, None)
        });
        // Outer `Result` comes from `with_current_exclusive`, inner one from
        // the build itself; the second tuple element is not used here.
        let l0_module = match build_result {
            Ok((Ok(built), _)) => built,
            Ok((Err(build_err), _)) => return Err(ModuleCompileError::from(build_err)),
            Err(device_err) => return Err(ModuleCompileError::from(device_err)),
        };

        // Re-key the table by `CString` so it can be queried with the `&CStr`
        // name that `get_function` receives.
        let arg_lens = arg_lens_by_name
            .into_iter()
            .map(|(name, lens)| (CString::new(name).unwrap(), lens))
            .collect();

        Ok(Mutex::new(Self {
            base: l0_module,
            arg_lens,
        }))
    }
}

pub fn get_function(
    hfunc: *mut *mut Function,
    hmod: *mut Module,
    name: *const c_char,
) -> Result<(), CUresult> {
    if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
        return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
    }
    let name = unsafe { CStr::from_ptr(name) };
    let (mut kernel, args_len) = unsafe { &*hmod }
        .try_lock()
        .map(|module| {
            Result::<_, CUresult>::Ok((
                l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
                module
                    .arg_lens
                    .get(name)
                    .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
                    .clone(),
            ))
        })
        .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
    kernel.set_indirect_access(
        l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
            | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
            | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
    )?;
    unsafe {
        *hfunc = Box::into_raw(Box::new(Function {
            base: kernel,
            arg_size: args_len,
        }))
    };
    Ok(())
}

/// Module-unload entry point; currently a no-op that ignores its argument and
/// always returns `Ok(())`.
///
/// NOTE(review): nothing is released here. If module handles are
/// heap-allocated by the caller side, they are not freed by this function —
/// confirm whether that is intended.
pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
    Ok(())
}