aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--cuda_base/src/lib.rs38
-rw-r--r--zluda/Cargo.toml1
-rw-r--r--zluda/src/impl/context.rs80
-rw-r--r--zluda/src/impl/device.rs30
-rw-r--r--zluda/src/impl/driver.rs79
-rw-r--r--zluda/src/impl/function.rs62
-rw-r--r--zluda/src/impl/link.rs86
-rw-r--r--zluda/src/impl/memory.rs70
-rw-r--r--zluda/src/impl/mod.rs48
-rw-r--r--zluda/src/impl/pointer.rs39
-rw-r--r--zluda/src/impl/test.rs157
-rw-r--r--zluda/src/lib.rs27
12 files changed, 347 insertions, 370 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 0cc1f53..833d372 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -150,6 +150,10 @@ impl VisitMut for FixFnSignatures {
}
}
+const MODULES: &[&str] = &[
+ "context", "device", "driver", "function", "link", "memory", "module", "pointer",
+];
+
#[proc_macro]
pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
let mut path = parse_macro_input!(tokens as syn::Path);
@@ -161,8 +165,9 @@ pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
.0
.ident
.to_string();
+ let already_has_module = MODULES.contains(&&*path.segments.last().unwrap().ident.to_string());
let segments: Vec<String> = split(&fn_[2..]); // skip "cu"
- let fn_path = join(segments);
+ let fn_path = join(segments, !already_has_module);
quote! {
#path #fn_path
}
@@ -181,23 +186,16 @@ fn split(fn_: &str) -> Vec<String> {
result
}
-fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
+fn join(fn_: Vec<String>, find_module: bool) -> Punctuated<Ident, Token![::]> {
fn full_form(segment: &str) -> Option<&[&str]> {
Some(match segment {
"ctx" => &["context"],
+ "func" => &["function"],
+ "mem" => &["memory"],
"memcpy" => &["memory", "copy"],
_ => return None,
})
}
- const MODULES: &[&str] = &[
- "context",
- "device",
- "function",
- "link",
- "memory",
- "module",
- "pointer"
- ];
let mut normalized: Vec<&str> = Vec::new();
for segment in fn_.iter() {
match full_form(segment) {
@@ -205,18 +203,20 @@ fn join(fn_: Vec<String>) -> Punctuated<Ident, Token![::]> {
None => normalized.push(&*segment),
}
}
+ if !find_module {
+ return [Ident::new(&normalized.join("_"), Span::call_site())]
+ .into_iter()
+ .collect();
+ }
if !MODULES.contains(&normalized[0]) {
- let mut globalized = vec!["global"];
+ let mut globalized = vec!["driver"];
globalized.extend(normalized);
normalized = globalized;
}
let (module, path) = normalized.split_first().unwrap();
let path = path.join("_");
- let mut result = Punctuated::new();
- result.extend(
- [module, &&*path]
- .into_iter()
- .map(|s| Ident::new(s, Span::call_site())),
- );
- result
+ [module, &&*path]
+ .into_iter()
+ .map(|s| Ident::new(s, Span::call_site()))
+ .collect()
}
diff --git a/zluda/Cargo.toml b/zluda/Cargo.toml
index 0a4c406..14e9845 100644
--- a/zluda/Cargo.toml
+++ b/zluda/Cargo.toml
@@ -20,6 +20,7 @@ num_enum = "0.4"
lz4-sys = "1.9"
tempfile = "3"
paste = "1.0"
+rustc-hash = "1.1"
[target.'cfg(windows)'.dependencies]
winapi = { version = "0.3", features = ["heapapi", "std"] }
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index d1a135f..973febc 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -1,4 +1,46 @@
+use super::{driver, FromCuda, ZludaObject};
+use cuda_types::*;
use hip_runtime_sys::*;
+use rustc_hash::FxHashSet;
+use std::{cell::RefCell, ptr, sync::Mutex};
+
+thread_local! {
+ pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
+}
+
+pub(crate) struct Context {
+ pub(crate) device: hipDevice_t,
+ pub(crate) mutable: Mutex<OwnedByContext>,
+}
+
+pub(crate) struct OwnedByContext {
+ pub(crate) ref_count: usize, // only used by primary context
+ pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
+ pub(crate) _streams: FxHashSet<hipStream_t>,
+ pub(crate) _modules: FxHashSet<CUmodule>,
+}
+
+impl ZludaObject for Context {
+ const COOKIE: usize = 0x5f867c6d9cb73315;
+
+ type CudaHandle = CUcontext;
+
+ fn drop_checked(&mut self) -> CUresult {
+ Ok(())
+ }
+}
+
+pub(crate) fn new(device: hipDevice_t) -> Context {
+ Context {
+ device,
+ mutable: Mutex::new(OwnedByContext {
+ ref_count: 0,
+ _memory: FxHashSet::default(),
+ _streams: FxHashSet::default(),
+ _modules: FxHashSet::default(),
+ }),
+ }
+}
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
unsafe { hipDeviceGetLimit(pvalue, limit) }
@@ -11,3 +53,41 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
pub(crate) fn synchronize() -> hipError_t {
unsafe { hipDeviceSynchronize() }
}
+
+pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
+ let dev = driver::device(hip_dev)?;
+ Ok(dev.primary_context())
+}
+
+pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
+ let new_device = if raw_ctx.0 == ptr::null_mut() {
+ CONTEXT_STACK.with(|stack| {
+ let mut stack = stack.borrow_mut();
+ if let Some((_, old_device)) = stack.pop() {
+ if let Some((_, new_device)) = stack.last() {
+ if old_device != *new_device {
+ return Some(*new_device);
+ }
+ }
+ }
+ None
+ })
+ } else {
+ let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
+ let device = ctx.device;
+ CONTEXT_STACK.with(move |stack| {
+ let mut stack = stack.borrow_mut();
+ let last_device = stack.last().map(|(_, dev)| *dev);
+ stack.push((raw_ctx, device));
+ match last_device {
+ None => Some(device),
+ Some(last_device) if last_device != device => Some(device),
+ _ => None,
+ }
+ })
+ };
+ if let Some(dev) = new_device {
+ unsafe { hipSetDevice(dev)? };
+ }
+ Ok(())
+}
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index a2a56c9..8836c1e 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -2,6 +2,8 @@ use cuda_types::*;
use hip_runtime_sys::*;
use std::{mem, ptr};
+use super::context;
+
const PROJECT_SUFFIX: &[u8] = b" [ZLUDA]\0";
pub const COMPUTE_CAPABILITY_MAJOR: i32 = 8;
pub const COMPUTE_CAPABILITY_MINOR: i32 = 8;
@@ -307,3 +309,31 @@ pub(crate) fn get_count(count: &mut ::core::ffi::c_int) -> hipError_t {
fn clamp_usize(x: usize) -> i32 {
usize::min(x, i32::MAX as usize) as i32
}
+
+pub(crate) fn primary_context_retain(
+ pctx: &mut CUcontext,
+ hip_dev: hipDevice_t,
+) -> Result<(), CUerror> {
+ let (ctx, raw_ctx) = context::get_primary(hip_dev)?;
+ {
+ let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
+ mutable_ctx.ref_count += 1;
+ }
+ *pctx = raw_ctx;
+ Ok(())
+}
+
+pub(crate) fn primary_context_release(hip_dev: hipDevice_t) -> Result<(), CUerror> {
+ let (ctx, _) = context::get_primary(hip_dev)?;
+ {
+ let mut mutable_ctx = ctx.mutable.lock().map_err(|_| CUerror::UNKNOWN)?;
+ if mutable_ctx.ref_count == 0 {
+ return Err(CUerror::INVALID_CONTEXT);
+ }
+ mutable_ctx.ref_count -= 1;
+ if mutable_ctx.ref_count == 0 {
+ // TODO: drop all children
+ }
+ }
+ Ok(())
+}
diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs
new file mode 100644
index 0000000..7ff2f54
--- /dev/null
+++ b/zluda/src/impl/driver.rs
@@ -0,0 +1,79 @@
+use cuda_types::*;
+use hip_runtime_sys::*;
+use std::{
+ ffi::{CStr, CString},
+ mem, slice,
+ sync::OnceLock,
+};
+
+use crate::r#impl::context;
+
+use super::LiveCheck;
+
+pub(crate) struct GlobalState {
+ pub devices: Vec<Device>,
+}
+
+pub(crate) struct Device {
+ pub(crate) _comgr_isa: CString,
+ primary_context: LiveCheck<context::Context>,
+}
+
+impl Device {
+ pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
+ unsafe {
+ (
+ self.primary_context.data.assume_init_ref(),
+ self.primary_context.as_handle(),
+ )
+ }
+ }
+}
+
+pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
+ global_state()?
+ .devices
+ .get(dev as usize)
+ .ok_or(CUerror::INVALID_DEVICE)
+}
+
+pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
+ static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
+ fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
+ unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
+ }
+ GLOBAL_STATE
+ .get_or_init(|| {
+ let mut device_count = 0;
+ unsafe { hipGetDeviceCount(&mut device_count) }?;
+ Ok(GlobalState {
+ devices: (0..device_count)
+ .map(|i| {
+ let mut props = unsafe { mem::zeroed() };
+ unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
+ Ok::<_, CUerror>(Device {
+ _comgr_isa: CStr::from_bytes_until_nul(cast_slice(
+ &props.gcnArchName[..],
+ ))
+ .map_err(|_| CUerror::UNKNOWN)?
+ .to_owned(),
+ primary_context: LiveCheck::new(context::new(i)),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?,
+ })
+ })
+ .as_ref()
+ .map_err(|e| *e)
+}
+
+pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
+ unsafe { hipInit(flags) }?;
+ global_state()?;
+ Ok(())
+}
+
+pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
+ *version = cuda_types::CUDA_VERSION as i32;
+ Ok(())
+}
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 7f35bb4..8d006ec 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -1,26 +1,46 @@
-use hip_runtime_sys::{hipError_t, hipFuncAttribute, hipFuncGetAttribute, hipFuncGetAttributes, hipFunction_attribute, hipLaunchKernel, hipModuleLaunchKernel};
-
-use super::{CUresult, HasLivenessCookie, LiveCheck};
-use crate::cuda::{CUfunction, CUfunction_attribute, CUstream};
-use ::std::os::raw::{c_uint, c_void};
-use std::{mem, ptr};
+use hip_runtime_sys::*;
pub(crate) fn get_attribute(
- pi: *mut i32,
- cu_attrib: CUfunction_attribute,
- func: CUfunction,
+ pi: &mut i32,
+ cu_attrib: hipFunction_attribute,
+ func: hipFunction_t,
+) -> hipError_t {
+ // TODO: implement HIP_FUNC_ATTRIBUTE_PTX_VERSION
+ // TODO: implement HIP_FUNC_ATTRIBUTE_BINARY_VERSION
+ unsafe { hipFuncGetAttribute(pi, cu_attrib, func) }?;
+ if cu_attrib == hipFunction_attribute::HIP_FUNC_ATTRIBUTE_NUM_REGS {
+ *pi = (*pi).max(1);
+ }
+ Ok(())
+}
+
+pub(crate) fn launch_kernel(
+ f: hipFunction_t,
+ grid_dim_x: ::core::ffi::c_uint,
+ grid_dim_y: ::core::ffi::c_uint,
+ grid_dim_z: ::core::ffi::c_uint,
+ block_dim_x: ::core::ffi::c_uint,
+ block_dim_y: ::core::ffi::c_uint,
+ block_dim_z: ::core::ffi::c_uint,
+ shared_mem_bytes: ::core::ffi::c_uint,
+ stream: hipStream_t,
+ kernel_params: *mut *mut ::core::ffi::c_void,
+ extra: *mut *mut ::core::ffi::c_void,
) -> hipError_t {
- if pi == ptr::null_mut() || func == ptr::null_mut() {
- return hipError_t::hipErrorInvalidValue;
+ // TODO: fix constants in extra
+ unsafe {
+ hipModuleLaunchKernel(
+ f,
+ grid_dim_x,
+ grid_dim_y,
+ grid_dim_z,
+ block_dim_x,
+ block_dim_y,
+ block_dim_z,
+ shared_mem_bytes,
+ stream,
+ kernel_params,
+ extra,
+ )
}
- let attrib = match cu_attrib {
- CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
- hipFunction_attribute::HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK
- }
- CUfunction_attribute::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES => {
- hipFunction_attribute::HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES
- }
- _ => return hipError_t::hipErrorInvalidValue,
- };
- unsafe { hipFuncGetAttribute(pi, attrib, func as _) }
}
diff --git a/zluda/src/impl/link.rs b/zluda/src/impl/link.rs
deleted file mode 100644
index d66608f..0000000
--- a/zluda/src/impl/link.rs
+++ /dev/null
@@ -1,86 +0,0 @@
-use std::{
- ffi::{c_void, CStr},
- mem, ptr, slice,
-};
-
-use hip_runtime_sys::{hipCtxGetDevice, hipError_t, hipGetDeviceProperties};
-
-use crate::{
- cuda::{CUjitInputType, CUjit_option, CUlinkState, CUresult},
- hip_call,
-};
-
-use super::module::{self, SpirvModule};
-
-struct LinkState {
- modules: Vec<SpirvModule>,
- result: Option<Vec<u8>>,
-}
-
-pub(crate) unsafe fn create(
- num_options: u32,
- options: *mut CUjit_option,
- option_values: *mut *mut c_void,
- state_out: *mut CUlinkState,
-) -> CUresult {
- if state_out == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
- }
- let state = Box::new(LinkState {
- modules: Vec::new(),
- result: None,
- });
- *state_out = mem::transmute(state);
- CUresult::CUDA_SUCCESS
-}
-
-pub(crate) unsafe fn add_data(
- state: CUlinkState,
- type_: CUjitInputType,
- data: *mut c_void,
- size: usize,
- name: *const i8,
- num_options: u32,
- options: *mut CUjit_option,
- option_values: *mut *mut c_void,
-) -> Result<(), hipError_t> {
- if state == ptr::null_mut() {
- return Err(hipError_t::hipErrorInvalidValue);
- }
- let state: *mut LinkState = mem::transmute(state);
- let state = &mut *state;
- // V-RAY specific hack
- if state.modules.len() == 2 {
- return Err(hipError_t::hipSuccess);
- }
- let spirv_data = SpirvModule::new_raw(data as *const _)?;
- state.modules.push(spirv_data);
- Ok(())
-}
-
-pub(crate) unsafe fn complete(
- state: CUlinkState,
- cubin_out: *mut *mut c_void,
- size_out: *mut usize,
-) -> Result<(), hipError_t> {
- let mut dev = 0;
- hip_call! { hipCtxGetDevice(&mut dev) };
- let mut props = unsafe { mem::zeroed() };
- hip_call! { hipGetDeviceProperties(&mut props, dev) };
- let state: &mut LinkState = mem::transmute(state);
- let spirv_bins = state.modules.iter().map(|m| &m.binaries[..]);
- let should_link_ptx_impl = state.modules.iter().find_map(|m| m.should_link_ptx_impl);
- let mut arch_binary = module::compile_amd(&props, spirv_bins, should_link_ptx_impl)
- .map_err(|_| hipError_t::hipErrorUnknown)?;
- let ptr = arch_binary.as_mut_ptr();
- let size = arch_binary.len();
- state.result = Some(arch_binary);
- *cubin_out = ptr as _;
- *size_out = size;
- Ok(())
-}
-
-pub(crate) unsafe fn destroy(state: CUlinkState) -> CUresult {
- let state: Box<LinkState> = mem::transmute(state);
- CUresult::CUDA_SUCCESS
-}
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index 6041623..b23afa9 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -1,55 +1,25 @@
-use hip_runtime_sys::{
- hipDrvMemcpy3D, hipError_t, hipMemcpy3D, hipMemcpy3DParms, hipMemoryType, hipPitchedPtr,
- hipPos, HIP_MEMCPY3D,
-};
-use std::ptr;
+use hip_runtime_sys::*;
-use crate::{
- cuda::{CUDA_MEMCPY3D_st, CUdeviceptr, CUmemorytype, CUresult},
- hip_call,
-};
+pub(crate) fn alloc_v2(dptr: *mut hipDeviceptr_t, bytesize: usize) -> hipError_t {
+ unsafe { hipMalloc(dptr.cast(), bytesize) }
+}
+
+pub(crate) fn free_v2(dptr: hipDeviceptr_t) -> hipError_t {
+ unsafe { hipFree(dptr.0) }
+}
-// TODO change HIP impl to 64 bits
-pub(crate) unsafe fn copy_3d(cu_copy: *const CUDA_MEMCPY3D_st) -> Result<(), hipError_t> {
- if cu_copy == ptr::null() {
- return Err(hipError_t::hipErrorInvalidValue);
- }
- let cu_copy = *cu_copy;
- let hip_copy = HIP_MEMCPY3D {
- srcXInBytes: cu_copy.srcXInBytes as u32,
- srcY: cu_copy.srcY as u32,
- srcZ: cu_copy.srcZ as u32,
- srcLOD: cu_copy.srcLOD as u32,
- srcMemoryType: memory_type(cu_copy.srcMemoryType)?,
- srcHost: cu_copy.srcHost,
- srcDevice: cu_copy.srcDevice.0 as _,
- srcArray: cu_copy.srcArray as _,
- srcPitch: cu_copy.srcPitch as u32,
- srcHeight: cu_copy.srcHeight as u32,
- dstXInBytes: cu_copy.dstXInBytes as u32,
- dstY: cu_copy.dstY as u32,
- dstZ: cu_copy.dstZ as u32,
- dstLOD: cu_copy.dstLOD as u32,
- dstMemoryType: memory_type(cu_copy.dstMemoryType)?,
- dstHost: cu_copy.dstHost,
- dstDevice: cu_copy.dstDevice.0 as _,
- dstArray: cu_copy.dstArray as _,
- dstPitch: cu_copy.dstPitch as u32,
- dstHeight: cu_copy.dstHeight as u32,
- WidthInBytes: cu_copy.WidthInBytes as u32,
- Height: cu_copy.Height as u32,
- Depth: cu_copy.Depth as u32,
- };
- hip_call! { hipDrvMemcpy3D(&hip_copy) };
- Ok(())
+pub(crate) fn copy_dto_h_v2(
+ dst_host: *mut ::core::ffi::c_void,
+ src_device: hipDeviceptr_t,
+ byte_count: usize,
+) -> hipError_t {
+ unsafe { hipMemcpyDtoH(dst_host, src_device, byte_count) }
}
-pub(crate) fn memory_type(cu: CUmemorytype) -> Result<hipMemoryType, hipError_t> {
- match cu {
- CUmemorytype::CU_MEMORYTYPE_HOST => Ok(hipMemoryType::hipMemoryTypeHost),
- CUmemorytype::CU_MEMORYTYPE_DEVICE => Ok(hipMemoryType::hipMemoryTypeDevice),
- CUmemorytype::CU_MEMORYTYPE_ARRAY => Ok(hipMemoryType::hipMemoryTypeArray),
- CUmemorytype::CU_MEMORYTYPE_UNIFIED => Ok(hipMemoryType::hipMemoryTypeUnified),
- _ => Err(hipError_t::hipErrorInvalidValue),
- }
+pub(crate) fn copy_hto_d_v2(
+ dst_device: hipDeviceptr_t,
+ src_host: *const ::core::ffi::c_void,
+ byte_count: usize,
+) -> hipError_t {
+ unsafe { hipMemcpyHtoD(dst_device, src_host.cast_mut(), byte_count) }
}
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 0400006..7b4afc5 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -1,10 +1,14 @@
use cuda_types::*;
use hip_runtime_sys::*;
-use std::mem::{self, ManuallyDrop};
+use std::mem::{self, ManuallyDrop, MaybeUninit};
pub(super) mod context;
pub(super) mod device;
+pub(super) mod driver;
+pub(super) mod function;
+pub(super) mod memory;
pub(super) mod module;
+pub(super) mod pointer;
#[cfg(debug_assertions)]
pub(crate) fn unimplemented() -> CUresult {
@@ -97,9 +101,12 @@ macro_rules! from_cuda_object {
from_cuda_nop!(
*mut i8,
+ *mut i32,
*mut usize,
- *const std::ffi::c_void,
+ *const ::core::ffi::c_void,
*const ::core::ffi::c_char,
+ *mut ::core::ffi::c_void,
+ *mut *mut ::core::ffi::c_void,
i32,
u32,
usize,
@@ -107,11 +114,14 @@ from_cuda_nop!(
CUdevice_attribute
);
from_cuda_transmute!(
- CUdevice => hipDevice_t,
CUuuid => hipUUID,
- CUfunction => hipFunction_t
+ CUfunction => hipFunction_t,
+ CUfunction_attribute => hipFunction_attribute,
+ CUstream => hipStream_t,
+ CUpointer_attribute => hipPointer_attribute,
+ CUdeviceptr_v2 => hipDeviceptr_t
);
-from_cuda_object!(module::Module);
+from_cuda_object!(module::Module, context::Context);
impl<'a> FromCuda<'a, CUlimit> for hipLimit_t {
fn from_cuda(limit: &'a CUlimit) -> Result<Self, CUerror> {
@@ -140,20 +150,28 @@ pub(crate) trait ZludaObject: Sized + Send + Sync {
#[repr(C)]
pub(crate) struct LiveCheck<T: ZludaObject> {
cookie: usize,
- data: ManuallyDrop<T>,
+ data: MaybeUninit<T>,
}
impl<T: ZludaObject> LiveCheck<T> {
- fn wrap(data: T) -> *mut Self {
- Box::into_raw(Box::new(LiveCheck {
+ fn new(data: T) -> Self {
+ LiveCheck {
cookie: T::COOKIE,
- data: ManuallyDrop::new(data),
- }))
+ data: MaybeUninit::new(data),
+ }
+ }
+
+ fn as_handle(&self) -> T::CudaHandle {
+ unsafe { mem::transmute_copy(self) }
+ }
+
+ fn wrap(data: T) -> *mut Self {
+ Box::into_raw(Box::new(Self::new(data)))
}
fn as_result(&self) -> Result<&T, CUerror> {
if self.cookie == T::COOKIE {
- Ok(&self.data)
+ Ok(unsafe { self.data.assume_init_ref() })
} else {
Err(T::LIVENESS_FAIL)
}
@@ -167,8 +185,8 @@ impl<T: ZludaObject> LiveCheck<T> {
fn drop_checked(&mut self) -> Result<Result<(), CUerror>, CUerror> {
if self.cookie == T::COOKIE {
self.cookie = 0;
- let result = self.data.drop_checked();
- unsafe { ManuallyDrop::drop(&mut self.data) };
+ let result = unsafe { self.data.assume_init_mut().drop_checked() };
+ unsafe { MaybeUninit::assume_init_drop(&mut self.data) };
Ok(result)
} else {
Err(T::LIVENESS_FAIL)
@@ -189,7 +207,3 @@ pub fn drop_checked<T: ZludaObject>(handle: T::CudaHandle) -> Result<(), CUerror
unsafe { ManuallyDrop::drop(&mut wrapped_object) };
underlying_error
}
-
-pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
- unsafe { hipInit(flags) }
-}
diff --git a/zluda/src/impl/pointer.rs b/zluda/src/impl/pointer.rs
index 1eef540..6b458a0 100644
--- a/zluda/src/impl/pointer.rs
+++ b/zluda/src/impl/pointer.rs
@@ -6,28 +6,27 @@ pub(crate) unsafe fn get_attribute(
data: *mut c_void,
attribute: hipPointer_attribute,
ptr: hipDeviceptr_t,
-) -> CUresult {
+) -> hipError_t {
if data == ptr::null_mut() {
- return CUresult::ERROR_INVALID_VALUE;
+ return hipError_t::ErrorInvalidValue;
}
- // TODO: implement by getting device ordinal & allocation start,
- // then go through every context for that device
- if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT {
- return CUresult::ERROR_NOT_SUPPORTED;
+ match attribute {
+ // TODO: implement by getting device ordinal & allocation start,
+ // then go through every context for that device
+ hipPointer_attribute::HIP_POINTER_ATTRIBUTE_CONTEXT => hipError_t::ErrorNotSupported,
+ hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE => {
+ let mut hip_result = hipMemoryType(0);
+ hipPointerGetAttribute(
+ (&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
+ attribute,
+ ptr,
+ )?;
+ let cuda_result = memory_type(hip_result)?;
+ unsafe { *(data.cast()) = cuda_result };
+ Ok(())
+ }
+ _ => unsafe { hipPointerGetAttribute(data, attribute, ptr) },
}
- if attribute == hipPointer_attribute::HIP_POINTER_ATTRIBUTE_MEMORY_TYPE {
- let mut hip_result = hipMemoryType(0);
- hipPointerGetAttribute(
- (&mut hip_result as *mut hipMemoryType).cast::<c_void>(),
- attribute,
- ptr,
- )?;
- let cuda_result = memory_type(hip_result)?;
- *(data as _) = cuda_result;
- } else {
- hipPointerGetAttribute(data, attribute, ptr)?;
- }
- Ok(())
}
fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
@@ -36,6 +35,6 @@ fn memory_type(cu: hipMemoryType) -> Result<CUmemorytype, hipErrorCode_t> {
hipMemoryType::hipMemoryTypeDevice => Ok(CUmemorytype::CU_MEMORYTYPE_DEVICE),
hipMemoryType::hipMemoryTypeArray => Ok(CUmemorytype::CU_MEMORYTYPE_ARRAY),
hipMemoryType::hipMemoryTypeUnified => Ok(CUmemorytype::CU_MEMORYTYPE_UNIFIED),
- _ => Err(hipErrorCode_t::hipErrorInvalidValue),
+ _ => Err(hipErrorCode_t::InvalidValue),
}
}
diff --git a/zluda/src/impl/test.rs b/zluda/src/impl/test.rs
deleted file mode 100644
index b36ccd8..0000000
--- a/zluda/src/impl/test.rs
+++ /dev/null
@@ -1,157 +0,0 @@
-#![allow(non_snake_case)]
-
-use crate::cuda as zluda;
-use crate::cuda::CUstream;
-use crate::cuda::CUuuid;
-use crate::{
- cuda::{CUdevice, CUdeviceptr},
- r#impl::CUresult,
-};
-use ::std::{
- ffi::c_void,
- os::raw::{c_int, c_uint},
-};
-use cuda_driver_sys as cuda;
-
-#[macro_export]
-macro_rules! cuda_driver_test {
- ($func:ident) => {
- paste! {
- #[test]
- fn [<$func _zluda>]() {
- $func::<crate::r#impl::test::Zluda>()
- }
-
- #[test]
- fn [<$func _cuda>]() {
- $func::<crate::r#impl::test::Cuda>()
- }
- }
- };
-}
-
-pub trait CudaDriverFns {
- fn cuInit(flags: c_uint) -> CUresult;
- fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult;
- fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult;
- fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult;
- fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult;
- fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult;
- fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult;
- fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
- fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
- fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
- fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
- fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
- fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
-}
-
-pub struct Zluda();
-
-impl CudaDriverFns for Zluda {
- fn cuInit(_flags: c_uint) -> CUresult {
- zluda::cuInit(_flags as _)
- }
-
- fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
- zluda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
- }
-
- fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
- zluda::cuCtxDestroy_v2(ctx as *mut _)
- }
-
- fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
- zluda::cuCtxPopCurrent_v2(pctx as *mut _)
- }
-
- fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
- zluda::cuCtxGetApiVersion(ctx as *mut _, version)
- }
-
- fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
- zluda::cuCtxGetCurrent(pctx as *mut _)
- }
- fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
- zluda::cuMemAlloc_v2(dptr as *mut _, bytesize)
- }
-
- fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
- zluda::cuDeviceGetUuid(uuid, CUdevice(dev))
- }
-
- fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
- zluda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
- }
-
- fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
- zluda::cuStreamGetCtx(hStream, pctx as _)
- }
-
- fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
- zluda::cuStreamCreate(stream, flags)
- }
-
- fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
- zluda::cuMemFree_v2(CUdeviceptr(dptr as _))
- }
-
- fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
- zluda::cuStreamDestroy_v2(stream)
- }
-}
-
-pub struct Cuda();
-
-impl CudaDriverFns for Cuda {
- fn cuInit(flags: c_uint) -> CUresult {
- unsafe { CUresult(cuda::cuInit(flags) as c_uint) }
- }
-
- fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
- unsafe { CUresult(cuda::cuCtxCreate_v2(pctx as *mut _, flags, dev) as c_uint) }
- }
-
- fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
- unsafe { CUresult(cuda::cuCtxDestroy_v2(ctx as *mut _) as c_uint) }
- }
-
- fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
- unsafe { CUresult(cuda::cuCtxPopCurrent_v2(pctx as *mut _) as c_uint) }
- }
-
- fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
- unsafe { CUresult(cuda::cuCtxGetApiVersion(ctx as *mut _, version) as c_uint) }
- }
-
- fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
- unsafe { CUresult(cuda::cuCtxGetCurrent(pctx as *mut _) as c_uint) }
- }
- fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
- unsafe { CUresult(cuda::cuMemAlloc_v2(dptr as *mut _, bytesize) as c_uint) }
- }
-
- fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
- unsafe { CUresult(cuda::cuDeviceGetUuid(uuid as *mut _, dev) as c_uint) }
- }
-
- fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
- unsafe { CUresult(cuda::cuDevicePrimaryCtxGetState(dev, flags, active) as c_uint) }
- }
-
- fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
- unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
- }
-
- fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
- unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
- }
-
- fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
- unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
- }
-
- fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
- unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
- }
-}
diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs
index 12d6ce0..bda67e1 100644
--- a/zluda/src/lib.rs
+++ b/zluda/src/lib.rs
@@ -27,10 +27,25 @@ macro_rules! implemented {
};
}
+macro_rules! implemented_in_function {
+ ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => {
+ $(
+ #[cfg_attr(not(test), no_mangle)]
+ #[allow(improper_ctypes)]
+ #[allow(improper_ctypes_definitions)]
+ pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
+ cuda_base::cuda_normalize_fn!( crate::r#impl::function::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda(&$arg_id)?),*)?;
+ Ok(())
+ }
+ )*
+ };
+}
+
cuda_base::cuda_function_declarations!(
unimplemented,
implemented <= [
cuCtxGetLimit,
+ cuCtxSetCurrent,
cuCtxSetLimit,
cuCtxSynchronize,
cuDeviceComputeCapability,
@@ -39,13 +54,25 @@ cuda_base::cuda_function_declarations!(
cuDeviceGetCount,
cuDeviceGetLuid,
cuDeviceGetName,
+ cuDevicePrimaryCtxRelease,
+ cuDevicePrimaryCtxRetain,
cuDeviceGetProperties,
cuDeviceGetUuid,
cuDeviceGetUuid_v2,
cuDeviceTotalMem_v2,
+ cuDriverGetVersion,
+ cuFuncGetAttribute,
cuInit,
+ cuMemAlloc_v2,
+ cuMemFree_v2,
+ cuMemcpyDtoH_v2,
+ cuMemcpyHtoD_v2,
cuModuleGetFunction,
cuModuleLoadData,
cuModuleUnload,
+ cuPointerGetAttribute,
+ ],
+ implemented_in_function <= [
+ cuLaunchKernel,
]
); \ No newline at end of file