diff options
Diffstat (limited to 'zluda/src/impl/pointer.rs')
-rw-r--r-- | zluda/src/impl/pointer.rs | 39 |
1 files changed, 19 insertions, 20 deletions
diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs index 1eef540..6b458a0 100644 --- a/zluda/src/impl/pointer.rs +++ b/zluda/src/impl/pointer.rs @@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute( data: *mut c_void, attribute: hipPointer_attribute, ptr: hipDeviceptr_t, -) -> CUresult { +) -> hipError_t { if data == ptr::null_mut() { - return CUresult::ERROR_INVALID_VALUE; + return hipError_t::ErrorInvalidValue; } - // TODO: implement by getting device ordinal & allocation start, - // then go through every context for that device - if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT { - return CUresult::ERROR_NOT_SUPPORTED; + match attribute { + // TODO: implement by getting device ordinal & allocation start, + // then go through every context for that device + hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported, + hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => { + let mut hip_result = hipMemoryType(0); + hipPointerGetAttribute( + (&mut hip_result as *mut hipMemoryType).cast::<c_void>(), + attribute, + ptr, + )?; + let cuda_result = memory_type(hip_result)?; + unsafe { *(data.cast()) = cuda_result }; + Ok(()) + } + _ => unsafe { hipPointerGetAttribute(data, attribute, ptr) }, } - if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE { - let mut hip_result = hipMemoryType(0); - hipPointerGetAttribute( - (&mut hip_result as *mut hipMemoryType).cast::<c_void>(), - attribute, - ptr, - )?; - let cuda_result = memory_type(hip_result)?; - *(data as _) = cuda_result; - } else { - hipPointerGetAttribute(data, attribute, ptr)?; - } - Ok(()) } fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { @@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> { hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE), hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY), hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED), - _ => Err(hipErrorCode_t::hipErrorInvalidValue), + _ => Err(hipErrorCode_t::InvalidValue), } } |