aboutsummaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/module.rs
blob: 06d050d8557cfc196f62f8549a9e68d23f8ccc34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use std::{ffi::c_void, ffi::CStr, mem, os::raw::c_char, ptr, slice, sync::Mutex};

use super::{transmute_lifetime, CUresult};
use ptx;

use super::context;

/// A [`ModuleData`] guarded by a [`Mutex`]. `get_function` receives this type
/// as a raw `*mut Module` handle and locks it before touching the inner data.
pub type Module = Mutex<ModuleData>;

/// State stored behind a [`Module`] handle.
pub struct ModuleData {
    /// The `l0::Module` built by `ModuleData::compile_spirv`.
    base: l0::Module,
}

/// Everything that can go wrong while turning PTX text into a [`Module`].
///
/// The `'a` lifetime comes from `ptx::Token<'a>` inside the parse error;
/// `compile_spirv` ties it to the lifetime of the PTX text it is given.
pub enum ModuleCompileError<'a> {
    /// PTX parsing failed. Holds the errors accumulated during parsing and,
    /// when the parser itself returned an error, that error as well (`None`
    /// when parsing returned `Ok` but the accumulated list was non-empty).
    Parse(
        Vec<ptx::ast::PtxError>,
        Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
    ),
    /// Translating the parsed module failed (`ptx::to_spirv` returned an error).
    Compile(ptx::TranslateError),
    /// An `l0` call reported a `ze_result_t` error.
    L0(l0::sys::ze_result_t),
    /// A failure already expressed as a `CUresult`.
    CUDA(CUresult),
}

impl<'a> ModuleCompileError<'a> {
    /// Placeholder: the body is empty, so this currently does nothing and
    /// returns `()`.
    // NOTE(review): presumably meant to expose a compiler build log eventually — confirm intent.
    pub fn get_build_log(&self) {}
}

/// Lets `?` lift a `ptx::TranslateError` into [`ModuleCompileError::Compile`].
impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
    fn from(translate_error: ptx::TranslateError) -> Self {
        Self::Compile(translate_error)
    }
}

/// Lets `?` lift an `l0` `ze_result_t` into [`ModuleCompileError::L0`].
impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
    fn from(ze_error: l0::sys::ze_result_t) -> Self {
        Self::L0(ze_error)
    }
}

impl<'a> From<CUresult> for ModuleCompileError<'a> {
    fn from(err: CUresult) -> Self {
        ModuleCompileError::CUDA(err)
    }
}

impl ModuleData {
    /// Parses `ptx_text`, translates the result with `ptx::to_spirv`, and builds
    /// an `l0::Module` from the generated code on the current device.
    ///
    /// On success the new module is returned wrapped in a [`Mutex`] (a [`Module`]).
    ///
    /// # Errors
    ///
    /// * [`ModuleCompileError::Parse`] when the parser returns an error, or
    ///   returns `Ok` but recorded at least one error in the error list.
    /// * [`ModuleCompileError::Compile`] when `ptx::to_spirv` fails.
    /// * [`ModuleCompileError::L0`] / [`ModuleCompileError::CUDA`] when
    ///   `with_current_exclusive` or `l0::Module::new_spirv` fails; both error
    ///   layers are converted through the `From` impls on `ModuleCompileError`.
    pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
        let mut errors = Vec::new();
        let parsed = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
        let ast = match parsed {
            Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
            // The parser can produce a module while still having recorded
            // errors; treat that as a failure too.
            Ok(_) if !errors.is_empty() => return Err(ModuleCompileError::Parse(errors, None)),
            Ok(ast) => ast,
        };
        let spirv = ptx::to_spirv(ast)?;
        // SAFETY: this views the buffer behind `spirv` as raw bytes. It assumes
        // `spirv` is a contiguous buffer of `u32` words (the length is scaled by
        // `size_of::<u32>()` accordingly); `u8` has alignment 1, and `spirv` is
        // neither moved nor mutated while `byte_il` is in use.
        let byte_il = unsafe {
            slice::from_raw_parts(
                spirv.as_ptr() as *const u8,
                spirv.len() * mem::size_of::<u32>(),
            )
        };
        // The outer `?` handles the error from `with_current_exclusive`, the
        // inner one the error from `l0::Module::new_spirv`.
        let module = super::device::with_current_exclusive(|dev| {
            l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
        })??;
        Ok(Mutex::new(Self { base: module }))
    }
}

/// Kernel handle created by `get_function` and handed out as a raw
/// `*mut Function`.
pub struct Function {
    /// The underlying `l0` kernel. The `'static` lifetime is produced in
    /// `get_function` via `transmute_lifetime` on a borrow of the module.
    // NOTE(review): presumably the owning module must outlive this kernel — confirm.
    base: l0::Kernel<'static>,
}

/// Creates a kernel named `name` from the module behind `hmod` and stores a
/// freshly boxed [`Function`] for it in `*hfunc`.
///
/// The box is released with `Box::into_raw`, so the pointer written to
/// `*hfunc` is not freed by this function.
///
/// # Errors
///
/// * `CUDA_ERROR_INVALID_VALUE` if `hfunc`, `hmod` or `name` is null.
/// * `CUDA_ERROR_ILLEGAL_STATE` if the module mutex cannot be acquired with
///   `try_lock`.
/// * The error from `l0::Kernel::new_resident`, converted by `?`.
pub fn get_function(
    hfunc: *mut *mut Function,
    hmod: *mut Module,
    name: *const c_char,
) -> Result<(), CUresult> {
    if hfunc.is_null() || hmod.is_null() || name.is_null() {
        return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
    }
    let kernel_name = unsafe { CStr::from_ptr(name) };
    let module_lock = unsafe { &*hmod };
    // The guard is bound inside the match arm, so the lock is released as soon
    // as the kernel has been created.
    let kernel = match module_lock.try_lock() {
        Ok(module) => {
            l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, kernel_name)
        }
        Err(_) => return Err(CUresult::CUDA_ERROR_ILLEGAL_STATE),
    }?;
    let function = Box::new(Function { base: kernel });
    unsafe { ptr::write(hfunc, Box::into_raw(function)) };
    Ok(())
}