aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/function.rs
blob: 8470620201f6b97f35a8cc187912f7886e1ba505 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
use hip_runtime_sys::{hipError_t, hipFuncGetAttributes};

use super::{CUresult, HasLivenessCookie, LiveCheck};
use crate::cuda::{CUfunction, CUfunction_attribute};
use ::std::os::raw::{c_uint, c_void};
use std::{mem, ptr};

pub(crate) fn get_attribute(
    pi: *mut i32,
    cu_attrib: CUfunction_attribute,
    func: CUfunction,
) -> hipError_t {
    if pi == ptr::null_mut() || func == ptr::null_mut() {
        return hipError_t::hipErrorInvalidValue;
    }
    let mut hip_attrib = unsafe { mem::zeroed() };
    let err = unsafe { hipFuncGetAttributes(&mut hip_attrib, func as _) };
    if err != hipError_t::hipSuccess {
        return err;
    }
    let value = match cu_attrib {
        CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => hip_attrib.maxThreadsPerBlock,
        CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => hip_attrib.sharedSizeBytes as i32,
        _ => return hipError_t::hipErrorInvalidValue,
    };
    unsafe { *pi = value };
    hipError_t::hipSuccess
}