aboutsummaryrefslogtreecommitdiffhomepage
path: root/notcuda/src/impl/function.rs
diff options
context:
space:
mode:
Diffstat (limited to 'notcuda/src/impl/function.rs')
-rw-r--r--notcuda/src/impl/function.rs40
1 files changed, 38 insertions, 2 deletions
diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs
index 394f806..27bf9b6 100644
--- a/notcuda/src/impl/function.rs
+++ b/notcuda/src/impl/function.rs
@@ -1,7 +1,9 @@
use ::std::os::raw::{c_uint, c_void};
-use std::ptr;
+use std::{hint, ptr};
-use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream};
+use crate::cuda::CUfunction_attribute;
+
+use super::{stream::Stream, CUresult, GlobalState, HasLivenessCookie, LiveCheck};
pub type Function = LiveCheck<FunctionData>;
@@ -23,6 +25,19 @@ pub struct FunctionData {
pub base: l0::Kernel<'static>,
pub arg_size: Vec<usize>,
pub use_shared_mem: bool,
+ pub properties: Option<Box<l0::sys::ze_kernel_properties_t>>,
+}
+
+impl FunctionData {
+ fn get_properties(&mut self) -> Result<&l0::sys::ze_kernel_properties_t, l0::sys::ze_result_t> {
+ if let None = self.properties {
+ self.properties = Some(self.base.get_properties()?)
+ }
+ match self.properties {
+ Some(ref props) => Ok(props.as_ref()),
+ None => unsafe { hint::unreachable_unchecked() },
+ }
+ }
}
pub fn launch_kernel(
@@ -74,3 +89,24 @@ pub fn launch_kernel(
Ok(())
})?
}
+
+pub(crate) fn get_attribute(
+ pi: *mut i32,
+ attrib: CUfunction_attribute,
+ func: *mut Function,
+) -> Result<(), CUresult> {
+ if pi == ptr::null_mut() || func == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ match attrib {
+ CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
+ let max_threads = GlobalState::lock_function(func, |func| {
+ let props = func.get_properties()?;
+ Ok::<_, CUresult>(props.maxSubgroupSize * props.maxNumSubgroups)
+ })??;
+ unsafe { *pi = max_threads as i32 };
+ Ok(())
+ }
+ _ => Err(CUresult::CUDA_ERROR_NOT_SUPPORTED),
+ }
+}