diff options
author | Andrzej Janik <[email protected]> | 2024-11-25 06:17:14 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-11-25 06:17:14 +0100 |
commit | 502b0c957e1fb58f5b6df678a26b8758349f8eb4 (patch) | |
tree | 9f9d22a9293c08090f54f0f65681a8cb4e86bfc8 | |
parent | c461cefd7d57edd430d74780e90d25859f3b7472 (diff) | |
download | ZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.tar.gz ZLUDA-502b0c957e1fb58f5b6df678a26b8758349f8eb4.zip |
Add more missing host-side code
-rw-r--r-- | cuda_base/src/lib.rs | 38 | ||||
-rw-r--r-- | zluda/Cargo.toml | 1 | ||||
-rw-r--r-- | zluda/src/impl/context.rs | 80 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 30 | ||||
-rw-r--r-- | zluda/src/impl/driver.rs | 79 | ||||
-rw-r--r-- | zluda/src/impl/function.rs | 62 | ||||
-rw-r--r-- | zluda/src/impl/link.rs | 86 | ||||
-rw-r--r-- | zluda/src/impl/memory.rs | 70 | ||||
-rw-r--r-- | zluda/src/impl/mod.rs | 48 | ||||
-rw-r--r-- | zluda/src/impl/pointer.rs | 39 | ||||
-rw-r--r-- | zluda/src/impl/test.rs | 157 | ||||
-rw-r--r-- | zluda/src/lib.rs | 27 |
12 files changed, 347 insertions, 370 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs index 0cc1f53..833d372 100644 --- a/cuda_base/src/lib.rs +++ b/cuda_base/src/lib.rs @@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures { } } +const MODULES: &[&str] = &[ + "context", "device", "driver", "function", "link", "memory", "module", "pointer", +]; + #[proc_macro] pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { let mut path = parse_macro_input!(tokens as syn::Path); @@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream { .0 .ident .to_string(); + let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string()); let segments: Vec<String> = split(&fn_[2..]); // skip "cu" - let fn_path = join(segments); + let fn_path = join(segments, !already_has_module); quote! { #path #fn_path } @@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> { result } -fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> { +fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> { fn full_form(segment: &str) -> Option<&[&str]> { Some(match segment { "ctx" => &["context"], + "func" => &["function"], + "mem" => &["memory"], "memcpy" => &["memory", "copy"], _ => return None, }) } - const MODULES: &[&str] = &[ - "context", - "device", - "function", - "link", - "memory", - "module", - "pointer" - ]; let mut normalized: Vec<&str> = Vec::new(); for segment in fn_.iter() { match full_form(segment) { @@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> { None => normalized.push(&*segment), } } + if !find_module { + return [Ident::new(&normalized.join("_"), Span::call_site())] + .into_iter() + .collect(); + } if !MODULES.contains(&normalized[0]) { - let mut globalized = vec!["global"]; + let mut globalized = vec!["driver"]; globalized.extend(normalized); normalized = globalized; } let (module, path) = normalized.split_first().unwrap(); let path = path.join("_"); - let mut result = Punctuated::new(); - result.extend( - [module, &&*path] - .into_iter() - .map(|s| Ident::new(s, Span::call_site())), - ); - result + [module, &&*path] + .into_iter() + .map(|s| Ident::new(s, Span::call_site())) + .collect() } diff --git a/zluda/Cargo.toml b/zluda/Cargo.toml index 0a4c406..14e9845 100644 --- a/zluda/Cargo.toml +++ b/zluda/Cargo.toml @@ -20,6 +20,7 @@ num_enum = "0.4" lz4-sys = "1.9" tempfile = "3" paste = "1.0" +rustc-hash = "1.1" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["heapapi", "std"] } diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index d1a135f..973febc 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -1,4 +1,46 @@ +use super::{driver, FromCuda, ZludaObject}; +use cuda_types::*; use hip_runtime_sys::*; +use rustc_hash::FxHashSet; +use std::{cell::RefCell, ptr, sync::Mutex}; + +thread_local! { + pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new()); +} + +pub(crate) struct Context { + pub(crate) device: hipDevice_t, + pub(crate) mutable: Mutex<OwnedByContext>, +} + +pub(crate) struct OwnedByContext { + pub(crate) ref_count: usize, // only used by primary context + pub(crate) _memory: FxHashSet<hipDeviceptr_t>, + pub(crate) _streams: FxHashSet<hipStream_t>, + pub(crate) _modules: FxHashSet<CUmodule>, +} + +impl ZludaObject for Context { + const COOKIE: usize = 0x5f867c6d9cb73315; + + type CudaHandle = CUcontext; + + fn drop_checked(&mut self) -> CUresult { + Ok(()) + } +} + +pub(crate) fn new(device: hipDevice_t) -> Context { + Context { + device, + mutable: Mutex::new(OwnedByContext { + ref_count: 0, + _memory: FxHashSet::default(), + _streams: FxHashSet::default(), + _modules: FxHashSet::default(), + }), + } +} pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t { unsafe { hipDeviceGetLimit(pvalue, limit) } @@ -11,3 +53,41 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t { pub(crate) fn synchronize() -> hipError_t { unsafe { hipDeviceSynchronize() } } + +pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> { + let dev = driver::device(hip_dev)?; + Ok(dev.primary_context()) +} + +pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult { + let new_device = if raw_ctx.0 == ptr::null_mut() { + CONTEXT_STACK.with(|stack| { + let mut stack = stack.borrow_mut(); + if let Some((_, old_device)) = stack.pop() { + if let Some((_, new_device)) = stack.last() { + if old_device != *new_device { + return Some(*new_device); + } + } + } + None + }) + } else { + let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?; + let device = ctx.device; + CONTEXT_STACK.with(move |stack| { + let mut stack = stack.borrow_mut(); + let last_device = stack.last().map(|(_, dev)| *dev); + stack.push((raw_ctx, device)); + match last_device { + None => Some(device), + Some(last_device) if last_device != device => Some(device), + _ => None, + } + }) + }; + if let Some(dev) = new_device { + unsafe { hipSetDevice(dev)? }; + } + Ok(()) +} diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index a2a56c9..8836c1e 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -2,6 +2,8 @@ use cuda_types::*; use hip_runtime_sys::*; use std::{mem, ptr}; +use super::context; + const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0"; pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8; pub const COMPUTE_CAPABILITY_MINOR: i32 = 8; @@ -307,3 +309,31 @@ pub(crate) fn get_count(count: &mut ::core::ffi::c_int) -> hipError_t { fn clamp_usize(x: usize) -> i32 { usize::min(x, i32::MAX as usize) as i32 } + +pub(crate) fn primary_context_retain( + pctx: &mut CUcontext, + hip_dev: hipDevice_t, +) -> Result<(), CUerror> { + let (ctx, raw_ctx) = context::get_primary(hip_dev)?; + { + let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?; + mutable_ctx.ref_count += 1; + } + *pctx = raw_ctx; + Ok(()) +} + +pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> Result<(), CUerror> { + let (ctx, _) = context::get_primary(hip_dev)?; + { + let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?; + if mutable_ctx.ref_count == 0 { + return Err(CUerror::INVALID_CONTEXT); + } + mutable_ctx.ref_count -= 1; + if mutable_ctx.ref_count == 0 { + // TODO: drop all children + } + } + Ok(()) +} diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs new file mode 100644 index 0000000..7ff2f54 --- /dev/null +++ b/zluda/src/impl/driver.rs @@ -0,0 +1,79 @@ +use cuda_types::*;
+use hip_runtime_sys::*;
+use std::{
+ ffi::{CStr, CString},
+ mem, slice,
+ sync::OnceLock,
+};
+
+use crate::r#impl::context;
+
+use super::LiveCheck;
+
+pub(crate) struct GlobalState {
+ pub devices: Vec<Device>,
+}
+
+pub(crate) struct Device {
+ pub(crate) _comgr_isa: CString,
+ primary_context: LiveCheck<context::Context>,
+}
+
+impl Device {
+ pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
+ unsafe {
+ (
+ self.primary_context.data.assume_init_ref(),
+ self.primary_context.as_handle(),
+ )
+ }
+ }
+}
+
+pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
+ global_state()?
+ .devices
+ .get(dev as usize)
+ .ok_or(CUerror::INVALID_DEVICE)
+}
+
+pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
+ static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
+ fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
+ unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
+ }
+ GLOBAL_STATE
+ .get_or_init(|| {
+ let mut device_count = 0;
+ unsafe { hipGetDeviceCount(&mut device_count) }?;
+ Ok(GlobalState {
+ devices: (0..device_count)
+ .map(|i| {
+ let mut props = unsafe { mem::zeroed() };
+ unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
+ Ok::<_, CUerror>(Device {
+ _comgr_isa: CStr::from_bytes_until_nul(cast_slice(
+ &props.gcnArchName[..],
+ ))
+ .map_err(|_| CUerror::UNKNOWN)?
+ .to_owned(),
+ primary_context: LiveCheck::new(context::new(i)),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?,
+ })
+ })
+ .as_ref()
+ .map_err(|e| *e)
+}
+
+pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
+ unsafe { hipInit(flags) }?;
+ global_state()?;
+ Ok(())
+}
+
+pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
+ *version = cuda_types::CUDA_VERSION as i32;
+ Ok(())
+}
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs index 7f35bb4..8d006ec 100644 --- a/zluda/src/impl/function.rs +++ b/zluda/src/impl/function.rs @@ -1,26 +1,46 @@ -use hip_runtime_sys::{hipError_t, hipFuncAttribute, hipFuncGetAttribute, hipFuncGetAttributes, hipFunction_attribute, hipLaunchKernel, hipModuleLaunchKernel}; - -use super::{CUresult, HasLivenessCookie, LiveCheck}; -use crate::cuda::{CUfunction, CUfunction_attribute, CUstream}; -use ::std::os::raw::{c_uint, c_void}; -use std::{mem, ptr}; +use hip_runtime_sys::*; pub(crate) fn get_attribute( - pi: *mut i32, - cu_attrib: CUfunction_attribute, - func: CUfunction, + pi: &mut i32, + cu_attrib: hipFunction_attribute, + func: hipFunction_t, +) -> hipError_t { + // TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION + // TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION + unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?; + if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS { + *pi = (*pi).max(1); + } + Ok(()) +} + +pub(crate) fn launch_kernel( + f: hipFunction_t, + grid_dim_x: ::core::ffi::c_uint, + grid_dim_y: ::core::ffi::c_uint, + grid_dim_z: ::core::ffi::c_uint, + block_dim_x: ::core::ffi::c_uint, + block_dim_y: ::core::ffi::c_uint, + block_dim_z: ::core::ffi::c_uint, + shared_mem_bytes: ::core::ffi::c_uint, + stream: hipStream_t, + kernel_params: *mut *mut ::core::ffi::c_void, + extra: *mut *mut ::core::ffi::c_void, ) -> hipError_t { - if pi == ptr::null_mut() || func == ptr::null_mut() { - return hipError_t::hipErrorInvalidValue; + // TODO: fix constants in extra + unsafe { + hipModuleLaunchKernel( + f, + grid_dim_x, + grid_dim_y, + grid_dim_z, + block_dim_x, + block_dim_y, + block_dim_z, + shared_mem_bytes, + stream, + kernel_params, + extra, + ) } - let attrib = match cu_attrib { - CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { - hipFunction_attribute::HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK - } - CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => { - hipFunction_attribute::HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES - } - _ => return hipError_t::hipErrorInvalidValue, - }; - unsafe { hipFuncGetAttribute(pi, attrib, func as _) } } diff --git a/zluda/src/impl/link.rs b/zluda/src/impl/link.rs deleted file mode 100644 index d66608f..0000000 --- a/zluda/src/impl/link.rs +++ /dev/null @@ -1,86 +0,0 @@ -use std::{ - ffi::{c_void, CStr}, - mem, ptr, slice, -}; - -use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties}; - -use crate::{ - cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult}, - hip_call, -}; - -use super::module::{self, SpirvModule}; - -struct LinkState { - modules: Vec<SpirvModule>, - result: Option<Vec<u8>>, -} - -pub(crate) unsafe fn create( - num_options: u32, - options: *mut CUjit_option, - option_values: *mut *mut c_void, - state_out: *mut CUlinkState, -) -> CUresult { - if state_out == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; - } - let state = Box::new(LinkState { - modules: Vec::new(), - result: None, - }); - *state_out = mem::transmute(state); - CUresult::CUDA_SUCCESS -} - -pub(crate) unsafe fn add_data( - state: CUlinkState, - type_: CUjitInputType, - data: *mut c_void, - size: usize, - name: *const i8, - num_options: u32, - options: *mut CUjit_option, - option_values: *mut *mut c_void, -) -> Result<(), hipError_t> { - if state == ptr::null_mut() { - return Err(hipError_t::hipErrorInvalidValue); - } - let state: *mut LinkState = mem::transmute(state); - let state = &mut *state; - // V-RAY specific hack - if state.modules.len() == 2 { - return Err(hipError_t::hipSuccess); - } - let spirv_data = SpirvModule::new_raw(data as *const _)?; - state.modules.push(spirv_data); - Ok(()) -} - -pub(crate) unsafe fn complete( - state: CUlinkState, - cubin_out: *mut *mut c_void, - size_out: *mut usize, -) -> Result<(), hipError_t> { - let mut dev = 0; - hip_call! { hipCtxGetDevice(&mut dev) }; - let mut props = unsafe { mem::zeroed() }; - hip_call! { hipGetDeviceProperties(&mut props, dev) }; - let state: &mut LinkState = mem::transmute(state); - let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]); - let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl); - let mut arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl) - .map_err(|_| hipError_t::hipErrorUnknown)?; - let ptr = arch_binary.as_mut_ptr(); - let size = arch_binary.len(); - state.result = Some(arch_binary); - *cubin_out = ptr as _; - *size_out = size; - Ok(()) -} - -pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult { - let state: Box<LinkState> = mem::transmute(state); - CUresult::CUDA_SUCCESS -} diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 6041623..b23afa9 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -1,55 +1,25 @@ -use hip_runtime_sys::{ - hipDrvMemcpy3D, hipError_t, hipMemcpy3D, hipMemcpy3DParms, hipMemoryType, hipPitchedPtr, - hipPos, HIP_MEMCPY3D, -}; -use std::ptr; +use hip_runtime_sys::*; -use crate::{ - cuda::{CUDA_MEMCPY3D_st, CUdeviceptr, CUmemorytype, CUresult}, - hip_call, -}; +pub(crate) fn alloc_v2(dptr: *mut hipDeviceptr_t, bytesize: usize) -> hipError_t { + unsafe { hipMalloc(dptr.cast(), bytesize) } +} + +pub(crate) fn free_v2(dptr: hipDeviceptr_t) -> hipError_t { + unsafe { hipFree(dptr.0) } +} -// TODO change HIP impl to 64 bits -pub(crate) unsafe fn copy_3d(cu_copy: *const CUDA_MEMCPY3D_st) -> Result<(), hipError_t> { - if cu_copy == ptr::null() { - return Err(hipError_t::hipErrorInvalidValue); - } - let cu_copy = *cu_copy; - let hip_copy = HIP_MEMCPY3D { - srcXInBytes: cu_copy.srcXInBytes as u32, - srcY: cu_copy.srcY as u32, - srcZ: cu_copy.srcZ as u32, - srcLOD: cu_copy.srcLOD as u32, - srcMemoryType: memory_type(cu_copy.srcMemoryType)?, - srcHost: cu_copy.srcHost, - srcDevice: cu_copy.srcDevice.0 as _, - srcArray: cu_copy.srcArray as _, - srcPitch: cu_copy.srcPitch as u32, - srcHeight: cu_copy.srcHeight as u32, - dstXInBytes: cu_copy.dstXInBytes as u32, - dstY: cu_copy.dstY as u32, - dstZ: cu_copy.dstZ as u32, - dstLOD: cu_copy.dstLOD as u32, - dstMemoryType: memory_type(cu_copy.dstMemoryType)?, - dstHost: cu_copy.dstHost, - dstDevice: cu_copy.dstDevice.0 as _, - dstArray: cu_copy.dstArray as _, - dstPitch: cu_copy.dstPitch as u32, - dstHeight: cu_copy.dstHeight as u32, - WidthInBytes: cu_copy.WidthInBytes as u32, - Height: cu_copy.Height as u32, - Depth: cu_copy.Depth as u32, - }; - hip_call! { hipDrvMemcpy3D(&hip_copy) }; - Ok(()) +pub(crate) fn copy_dto_h_v2( + dst_host: *mut ::core::ffi::c_void, + src_device: hipDeviceptr_t, + byte_count: usize, +) -> hipError_t { + unsafe { hipMemcpyDtoH(dst_host, src_device, byte_count) } } -pub(crate) fn memory_type(cu: CUmemorytype) -> Result<hipMemoryType, hipError_t> { - match cu { - CUmemorytype::CU_MEMORYTYPE_HOST => Ok(hipMemoryType::hipMemoryTypeHost), - CUmemorytype::CU_MEMORYTYPE_DEVICE => Ok(hipMemoryType::hipMemoryTypeDevice), - CUmemorytype::CU_MEMORYTYPE_ARRAY => Ok(hipMemoryType::hipMemoryTypeArray), - CUmemorytype::CU_MEMORYTYPE_UNIFIED => Ok(hipMemoryType::hipMemoryTypeUnified), - _ => Err(hipError_t::hipErrorInvalidValue), - } +pub(crate) fn copy_hto_d_v2( + dst_device: hipDeviceptr_t, + src_host: *const ::core::ffi::c_void, + byte_count: usize, +) -> hipError_t { + unsafe { hipMemcpyHtoD(dst_device, src_host.cast_mut(), byte_count) } } diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 0400006..7b4afc5 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -1,10 +1,14 @@ use cuda_types::*; use hip_runtime_sys::*; -use std::mem::{self, ManuallyDrop}; +use std::mem::{self, ManuallyDrop, MaybeUninit}; pub(super) mod context; pub(super) mod device; +pub(super) mod driver; +pub(super) mod function; +pub(super) mod memory; pub(super) mod module; +pub(super) mod pointer; #[cfg(debug_assertions)] pub(crate) fn unimplemented() -> CUresult { @@ -97,9 +101,12 @@ macro_rules! from_cuda_object { from_cuda_nop!( *mut i8, + *mut i32, *mut usize, - *const std::ffi::c_void, + *const ::core::ffi::c_void, *const ::core::ffi::c_char, + *mut ::core::ffi::c_void, + *mut *mut ::core::ffi::c_void, i32, u32, usize, @@ -107,11 +114,14 @@ from_cuda_nop!( CUdevice_attribute ); from_cuda_transmute!( - CUdevice => hipDevice_t, CUuuid => hipUUID, - CUfunction => hipFunction_t + CUfunction => hipFunction_t, + CUfunction_attribute => hipFunction_attribute, + CUstream => hipStream_t, + CUpointer_attribute => hipPointer_attribute, + CUdeviceptr_v2 => hipDeviceptr_t ); -from_cuda_object!(module::Module); +from_cuda_object!(module::Module, context::Context); impl<'a> FromCuda<'a, CUlimit> for hipLimit_t { fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> { @@ -140,20 +150,28 @@ pub(crate) trait ZludaObject: Sized + Send + Sync { #[repr(C)] pub(crate) struct LiveCheck<T: ZludaObject> { cookie: usize, - data: ManuallyDrop<T>, + data: MaybeUninit<T>, } impl<T: ZludaObject> LiveCheck<T> { - fn wrap(data: T) -> *mut Self { - Box::into_raw(Box::new(LiveCheck { + fn new(data: T) -> Self { + LiveCheck { cookie: T::COOKIE, - data: ManuallyDrop::new(data), - })) + data: MaybeUninit::new(data), + } + } + + fn as_handle(&self) -> T::CudaHandle { + unsafe { mem::transmute_copy(self) } + } + + fn wrap(data: T) -> *mut Self { + Box::into_raw(Box::new(Self::new(data))) } fn as_result(&self) -> Result<&T, CUerror> { if self.cookie == T::COOKIE { - Ok(&self.data) + Ok(unsafe { self.data.assume_init_ref() }) } else { Err(T::LIVENESS_FAIL) } @@ -167,8 +185,8 @@ impl<T: ZludaObject> LiveCheck<T> { fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> { if self.cookie == T::COOKIE { self.cookie = 0; - let result = self.data.drop_checked(); - unsafe { ManuallyDrop::drop(&mut self.data) }; + let result = unsafe { self.data.assume_init_mut().drop_checked() }; + unsafe { MaybeUninit::assume_init_drop(&mut self.data) }; Ok(result) } else { Err(T::LIVENESS_FAIL) @@ -189,7 +207,3 @@ pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror unsafe { ManuallyDrop::drop(&mut wrapped_object) }; underlying_error } - -pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t { - unsafe { hipInit(flags) } -} diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 1eef540..6b458a0 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute( data: *mut c_void, attribute: hipPointer_attribute, ptr: hipDeviceptr_t, -) -> CUresult { +) -> hipError_t { if data == ptr::null_mut() { - return CUresult::ERROR_INVALID_VALUE; + return hipError_t::ErrorInvalidValue; } - // TODO: implement by getting device ordinal & allocation start, - // then go through every context for that device - if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT { - return CUresult::ERROR_NOT_SUPPORTED; + match attribute { + // TODO: implement by getting device ordinal & allocation start, + // then go through every context for that device + hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported, + hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => { + let mut hip_result = hipMemoryType(0); + hipPointerGetAttribute( + (&mut hip_result as *mut hipMemoryType).cast::<c_void>(), + attribute, + ptr, + )?; + let cuda_result = memory_type(hip_result)?; + unsafe { *(data.cast()) = cuda_result }; + Ok(()) + } + _ => unsafe { hipPointerGetAttribute(data, attribute, ptr) }, } - if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE { - let mut hip_result = hipMemoryType(0); - hipPointerGetAttribute( - (&mut hip_result as *mut hipMemoryType).cast::<c_void>(), - attribute, - ptr, - )?; - let cuda_result = memory_type(hip_result)?; - *(data as _) = cuda_result; - } else { - hipPointerGetAttribute(data, attribute, ptr)?; - } - Ok(()) } fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { @@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE), hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY), hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED), - _ => Err(hipErrorCode_t::hipErrorInvalidValue), + _ => Err(hipErrorCode_t::InvalidValue), } } diff --git a/zluda/src/impl/test.rs b/zluda/src/impl/test.rs deleted file mode 100644 index b36ccd8..0000000 --- a/zluda/src/impl/test.rs +++ /dev/null @@ -1,157 +0,0 @@ -#![allow(non_snake_case)] - -use crate::cuda as zluda; -use crate::cuda::CUstream; -use crate::cuda::CUuuid; -use crate::{ - cuda::{CUdevice, CUdeviceptr}, - r#impl::CUresult, -}; -use ::std::{ - ffi::c_void, - os::raw::{c_int, c_uint}, -}; -use cuda_driver_sys as cuda; - -#[macro_export] -macro_rules! cuda_driver_test { - ($func:ident) => { - paste! { - #[test] - fn [<$func _zluda>]() { - $func::<crate::r#impl::test::Zluda>() - } - - #[test] - fn [<$func _cuda>]() { - $func::<crate::r#impl::test::Cuda>() - } - } - }; -} - -pub trait CudaDriverFns { - fn cuInit(flags: c_uint) -> CUresult; - fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult; - fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult; - fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult; - fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult; - fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult; - fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult; - fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult; - fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult; - fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult; - fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult; - fn cuMemFree_v2(mem: *mut c_void) -> CUresult; - fn cuStreamDestroy_v2(stream: CUstream) -> CUresult; -} - -pub struct Zluda(); - -impl CudaDriverFns for Zluda { - fn cuInit(_flags: c_uint) -> CUresult { - zluda::cuInit(_flags as _) - } - - fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { - zluda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev)) - } - - fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { - zluda::cuCtxDestroy_v2(ctx as *mut _) - } - - fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { - zluda::cuCtxPopCurrent_v2(pctx as *mut _) - } - - fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { - zluda::cuCtxGetApiVersion(ctx as *mut _, version) - } - - fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { - zluda::cuCtxGetCurrent(pctx as *mut _) - } - fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { - zluda::cuMemAlloc_v2(dptr as *mut _, bytesize) - } - - fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { - zluda::cuDeviceGetUuid(uuid, CUdevice(dev)) - } - - fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { - zluda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active) - } - - fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { - zluda::cuStreamGetCtx(hStream, pctx as _) - } - - fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { - zluda::cuStreamCreate(stream, flags) - } - - fn cuMemFree_v2(dptr: *mut c_void) -> CUresult { - zluda::cuMemFree_v2(CUdeviceptr(dptr as _)) - } - - fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { - zluda::cuStreamDestroy_v2(stream) - } -} - -pub struct Cuda(); - -impl CudaDriverFns for Cuda { - fn cuInit(flags: c_uint) -> CUresult { - unsafe { CUresult(cuda::cuInit(flags) as c_uint) } - } - - fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { - unsafe { CUresult(cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) as c_uint) } - } - - fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { - unsafe { CUresult(cuda::cuCtxDestroy_v2(ctx as *mut _) as c_uint) } - } - - fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { - unsafe { CUresult(cuda::cuCtxPopCurrent_v2(pctx as *mut _) as c_uint) } - } - - fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { - unsafe { CUresult(cuda::cuCtxGetApiVersion(ctx as *mut _, version) as c_uint) } - } - - fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { - unsafe { CUresult(cuda::cuCtxGetCurrent(pctx as *mut _) as c_uint) } - } - fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { - unsafe { CUresult(cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) as c_uint) } - } - - fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { - unsafe { CUresult(cuda::cuDeviceGetUuid(uuid as *mut _, dev) as c_uint) } - } - - fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { - unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) } - } - - fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { - unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) } - } - - fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { - unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) } - } - - fn cuMemFree_v2(mem: *mut c_void) -> CUresult { - unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) } - } - - fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { - unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) } - } -} diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 12d6ce0..bda67e1 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -27,10 +27,25 @@ macro_rules! implemented { }; } +macro_rules! implemented_in_function { + ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => { + $( + #[cfg_attr(not(test), no_mangle)] + #[allow(improper_ctypes)] + #[allow(improper_ctypes_definitions)] + pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type { + cuda_base::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?; + Ok(()) + } + )* + }; +} + cuda_base::cuda_function_declarations!( unimplemented, implemented <= [ cuCtxGetLimit, + cuCtxSetCurrent, cuCtxSetLimit, cuCtxSynchronize, cuDeviceComputeCapability, @@ -39,13 +54,25 @@ cuda_base::cuda_function_declarations!( cuDeviceGetCount, cuDeviceGetLuid, cuDeviceGetName, + cuDevicePrimaryCtxRelease, + cuDevicePrimaryCtxRetain, cuDeviceGetProperties, cuDeviceGetUuid, cuDeviceGetUuid_v2, cuDeviceTotalMem_v2, + cuDriverGetVersion, + cuFuncGetAttribute, cuInit, + cuMemAlloc_v2, + cuMemFree_v2, + cuMemcpyDtoH_v2, + cuMemcpyHtoD_v2, cuModuleGetFunction, cuModuleLoadData, cuModuleUnload, + cuPointerGetAttribute, + ], + implemented_in_function <= [ + cuLaunchKernel, ] );
\ No newline at end of file |