aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/pointer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/pointer.rs')
-rw-r--r--zluda/src/impl/pointer.rs39
1 files changed, 19 insertions, 20 deletions
diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs
index 1eef540..6b458a0 100644
--- a/zluda/src/impl/pointer.rs
+++ b/zluda/src/impl/pointer.rs
@@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute(
data: *mut c_void,
attribute: hipPointer_attribute,
ptr: hipDeviceptr_t,
-) -> CUresult {
+) -> hipError_t {
if data == ptr::null_mut() {
- return CUresult::ERROR_INVALID_VALUE;
+ return hipError_t::ErrorInvalidValue;
}
- // TODO: implement by getting device ordinal & allocation start,
- // then go through every context for that device
- if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT {
- return CUresult::ERROR_NOT_SUPPORTED;
+ match attribute {
+ // TODO: implement by getting device ordinal & allocation start,
+ // then go through every context for that device
+ hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported,
+ hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => {
+ let mut hip_result = hipMemoryType(0);
+ hipPointerGetAttribute(
+ (&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
+ attribute,
+ ptr,
+ )?;
+ let cuda_result = memory_type(hip_result)?;
+ unsafe { *(data.cast()) = cuda_result };
+ Ok(())
+ }
+ _ => unsafe { hipPointerGetAttribute(data, attribute, ptr) },
}
- if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE {
- let mut hip_result = hipMemoryType(0);
- hipPointerGetAttribute(
- (&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
- attribute,
- ptr,
- )?;
- let cuda_result = memory_type(hip_result)?;
- *(data as _) = cuda_result;
- } else {
- hipPointerGetAttribute(data, attribute, ptr)?;
- }
- Ok(())
}
fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE),
hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY),
hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED),
- _ => Err(hipErrorCode_t::hipErrorInvalidValue),
+ _ => Err(hipErrorCode_t::InvalidValue),
}
}