aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/function.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-12-02 00:29:57 +0100
committerGitHub <[email protected]>2024-12-02 00:29:57 +0100
commit7a6df9dcbf59edef371e7f63c16c64916ddb0c0b (patch)
tree7800524ba25d38c514f1c769c9c1b665542c5500 /zluda/src/impl/function.rs
parent870fed4bb69d919a10822032d65ec20f385df9d7 (diff)
downloadZLUDA-7a6df9dcbf59edef371e7f63c16c64916ddb0c0b.tar.gz
ZLUDA-7a6df9dcbf59edef371e7f63c16c64916ddb0c0b.zip
Fix host code and update to CUDA 12.4 (#299)
Diffstat (limited to 'zluda/src/impl/function.rs')
-rw-r--r--zluda/src/impl/function.rs62
1 files changed, 41 insertions, 21 deletions
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 7f35bb4..8d006ec 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -1,26 +1,46 @@
-use hip_runtime_sys::{hipError_t, hipFuncAttribute, hipFuncGetAttribute, hipFuncGetAttributes, hipFunction_attribute, hipLaunchKernel, hipModuleLaunchKernel};
-
-use super::{CUresult, HasLivenessCookie, LiveCheck};
-use crate::cuda::{CUfunction, CUfunction_attribute, CUstream};
-use ::std::os::raw::{c_uint, c_void};
-use std::{mem, ptr};
+use hip_runtime_sys::*;
pub(crate) fn get_attribute(
- pi: *mut i32,
- cu_attrib: CUfunction_attribute,
- func: CUfunction,
+ pi: &mut i32,
+ cu_attrib: hipFunction_attribute,
+ func: hipFunction_t,
+) -> hipError_t {
+ // TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION
+ // TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION
+ unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?;
+ if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS {
+ *pi = (*pi).max(1);
+ }
+ Ok(())
+}
+
+pub(crate) fn launch_kernel(
+ f: hipFunction_t,
+ grid_dim_x: ::core::ffi::c_uint,
+ grid_dim_y: ::core::ffi::c_uint,
+ grid_dim_z: ::core::ffi::c_uint,
+ block_dim_x: ::core::ffi::c_uint,
+ block_dim_y: ::core::ffi::c_uint,
+ block_dim_z: ::core::ffi::c_uint,
+ shared_mem_bytes: ::core::ffi::c_uint,
+ stream: hipStream_t,
+ kernel_params: *mut *mut ::core::ffi::c_void,
+ extra: *mut *mut ::core::ffi::c_void,
) -> hipError_t {
- if pi == ptr::null_mut() || func == ptr::null_mut() {
- return hipError_t::hipErrorInvalidValue;
+ // TODO: fix constants in extra
+ unsafe {
+ hipModuleLaunchKernel(
+ f,
+ grid_dim_x,
+ grid_dim_y,
+ grid_dim_z,
+ block_dim_x,
+ block_dim_y,
+ block_dim_z,
+ shared_mem_bytes,
+ stream,
+ kernel_params,
+ extra,
+ )
}
- let attrib = match cu_attrib {
- CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
- hipFunction_attribute::HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
- }
- CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => {
- hipFunction_attribute::HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
- }
- _ => return hipError_t::hipErrorInvalidValue,
- };
- unsafe { hipFuncGetAttribute(pi, attrib, func as _) }
}