aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl')
-rw-r--r--zluda/src/impl/pointer.rs64
1 files changed, 26 insertions, 38 deletions
diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs
index 2b925cd..1eef540 100644
--- a/zluda/src/impl/pointer.rs
+++ b/zluda/src/impl/pointer.rs
@@ -1,53 +1,41 @@
-use std::{ffi::c_void, mem, ptr};
-
-use hip_runtime_sys::{hipError_t, hipMemoryType, hipPointerGetAttributes};
-
-use crate::{
- cuda::{CUdeviceptr, CUmemorytype, CUpointer_attribute},
- hip_call,
-};
+use cuda_types::*;
+use hip_runtime_sys::*;
+use std::{ffi::c_void, ptr};
pub(crate) unsafe fn get_attribute(
data: *mut c_void,
- attribute: CUpointer_attribute,
- ptr: CUdeviceptr,
-) -> Result<(), hipError_t> {
+ attribute: hipPointer_attribute,
+ ptr: hipDeviceptr_t,
+) -> CUresult {
if data == ptr::null_mut() {
- return Err(hipError_t::hipErrorInvalidValue);
+ return CUresult::ERROR_INVALID_VALUE;
+ }
+ // TODO: implement by getting device ordinal & allocation start,
+ // then go through every context for that device
+ if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT {
+ return CUresult::ERROR_NOT_SUPPORTED;
}
- let mut attribs = mem::zeroed();
- hip_call! { hipPointerGetAttributes(&mut attribs, ptr.0 as _) };
- match attribute {
- CUpointer_attribute::CU_POINTER_ATTRIBUTE_CONTEXT => {
- *(data as *mut _) = attribs.device;
- Ok(())
- }
- CUpointer_attribute::CU_POINTER_ATTRIBUTE_MEMORY_TYPE => {
- *(data as *mut _) = memory_type(attribs.memoryType)?;
- Ok(())
- }
- CUpointer_attribute::CU_POINTER_ATTRIBUTE_DEVICE_POINTER => {
- *(data as *mut _) = attribs.devicePointer;
- Ok(())
- }
- CUpointer_attribute::CU_POINTER_ATTRIBUTE_HOST_POINTER => {
- *(data as *mut _) = attribs.hostPointer;
- Ok(())
- }
- CUpointer_attribute::CU_POINTER_ATTRIBUTE_IS_MANAGED => {
- *(data as *mut _) = attribs.isManaged;
- Ok(())
- }
- _ => Err(hipError_t::hipErrorNotSupported),
+ if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE {
+ let mut hip_result = hipMemoryType(0);
+ hipPointerGetAttribute(
+ (&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
+ attribute,
+ ptr,
+ )?;
+ let cuda_result = memory_type(hip_result)?;
+ *(data as _) = cuda_result;
+ } else {
+ hipPointerGetAttribute(data, attribute, ptr)?;
}
+ Ok(())
}
-pub(crate) fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipError_t> {
+fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
match cu {
hipMemoryType::hipMemoryTypeHost => Ok(CUmemorytype::CU_MEMORYTYPE_HOST),
hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE),
hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY),
hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED),
- _ => Err(hipError_t::hipErrorInvalidValue),
+ _ => Err(hipErrorCode_t::hipErrorInvalidValue),
}
}