diff options
Diffstat (limited to 'zluda/src/impl')
-rw-r--r-- | zluda/src/impl/pointer.rs | 64 |
1 files changed, 26 insertions, 38 deletions
diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 2b925cd..1eef540 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -1,53 +1,41 @@ -use std::{ffi::c_void, mem, ptr}; - -use hip_runtime_sys::{hipError_t, hipMemoryType, hipPointerGetAttributes}; - -use crate::{ - cuda::{CUdeviceptr, CUmemorytype, CUpointer_attribute}, - hip_call, -}; +use cuda_types::*; +use hip_runtime_sys::*; +use std::{ffi::c_void, ptr}; pub(crate) unsafe fn get_attribute( data: *mut c_void, - attribute: CUpointer_attribute, - ptr: CUdeviceptr, -) -> Result<(), hipError_t> { + attribute: hipPointer_attribute, + ptr: hipDeviceptr_t, +) -> CUresult { if data == ptr::null_mut() { - return Err(hipError_t::hipErrorInvalidValue); + return CUresult::ERROR_INVALID_VALUE; + } + // TODO: implement by getting device ordinal & allocation start, + // then go through every context for that device + if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT { + return CUresult::ERROR_NOT_SUPPORTED; } - let mut attribs = mem::zeroed(); - hip_call! { hipPointerGetAttributes(&mut attribs, ptr.0 as _) }; - match attribute { - CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT => { - *(data as *mut _) = attribs.device; - Ok(()) - } - CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE => { - *(data as *mut _) = memory_type(attribs.memoryType)?; - Ok(()) - } - CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER => { - *(data as *mut _) = attribs.devicePointer; - Ok(()) - } - CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER => { - *(data as *mut _) = attribs.hostPointer; - Ok(()) - } - CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED => { - *(data as *mut _) = attribs.isManaged; - Ok(()) - } - _ => Err(hipError_t::hipErrorNotSupported), + if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE { + let mut hip_result = hipMemoryType(0); + hipPointerGetAttribute( + (&mut hip_result as *mut hipMemoryType).cast::<c_void>(), + attribute, + ptr, + )?; + let cuda_result = memory_type(hip_result)?; + *(data as _) = cuda_result; + } else { + hipPointerGetAttribute(data, attribute, ptr)?; } + Ok(()) } -pub(crate) fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipError_t> { +fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { match cu { hipMemoryType::hipMemoryTypeHost => Ok(CUmemorytype::CU_MEMORYTYPE_HOST), hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE), hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY), hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED), - _ => Err(hipError_t::hipErrorInvalidValue), + _ => Err(hipErrorCode_t::hipErrorInvalidValue), } } |