aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--cuda_base/src/lib.rs47
-rw-r--r--cuda_types/Cargo.toml1
-rw-r--r--cuda_types/src/lib.rs5
-rw-r--r--zluda/Cargo.toml1
-rw-r--r--zluda/src/impl/mod.rs228
-rw-r--r--zluda/src/lib.rs46
-rw-r--r--zluda_bindgen/src/main.rs7
7 files changed, 109 insertions, 226 deletions
diff --git a/cuda_base/src/lib.rs b/cuda_base/src/lib.rs
index 64e33ef..366edd7 100644
--- a/cuda_base/src/lib.rs
+++ b/cuda_base/src/lib.rs
@@ -1,6 +1,7 @@
extern crate proc_macro;
use proc_macro::TokenStream;
+use proc_macro2::Span;
use quote::{quote, ToTokens};
use rustc_hash::FxHashMap;
use std::iter;
@@ -148,3 +149,49 @@ impl VisitMut for FixFnSignatures {
s.inputs.pop_punct();
}
}
+
+#[proc_macro]
+pub fn cuda_normalize_fn(tokens: TokenStream) -> TokenStream {
+ let mut path = parse_macro_input!(tokens as syn::Path);
+ let fn_ = path
+ .segments
+ .pop()
+ .unwrap()
+ .into_tuple()
+ .0
+ .ident
+ .to_string();
+ let known_modules = [
+ "context", "device", "function", "link", "memory", "module", "pointer",
+ ];
+ let segments: Vec<String> = split(&fn_[2..]);
+ let fn_path = join(segments, &known_modules);
+ quote! {
+ #path #fn_path
+ }
+ .into()
+}
+
+fn split(fn_: &str) -> Vec<String> {
+ let mut result = Vec::new();
+ for c in fn_.chars() {
+ if c.is_ascii_uppercase() {
+ result.push(c.to_ascii_lowercase().to_string());
+ } else {
+ result.last_mut().unwrap().push(c);
+ }
+ }
+ result
+}
+
+fn join(fn_: Vec<String>, known_modules: &[&str]) -> Punctuated<Ident, Token![::]> {
+ let (prefix, suffix) = fn_.split_at(1);
+ if known_modules.contains(&&*prefix[0]) {
+ [&prefix[0], &suffix.join("_")]
+ .into_iter()
+ .map(|seg| Ident::new(seg, Span::call_site()))
+ .collect()
+ } else {
+ iter::once(Ident::new(&fn_.join("_"), Span::call_site())).collect()
+ }
+}
diff --git a/cuda_types/Cargo.toml b/cuda_types/Cargo.toml
index e779830..2ca470f 100644
--- a/cuda_types/Cargo.toml
+++ b/cuda_types/Cargo.toml
@@ -6,3 +6,4 @@ edition = "2018"
[dependencies]
cuda_base = { path = "../cuda_base" }
+hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
diff --git a/cuda_types/src/lib.rs b/cuda_types/src/lib.rs
index bd350e4..945c0a7 100644
--- a/cuda_types/src/lib.rs
+++ b/cuda_types/src/lib.rs
@@ -8083,3 +8083,8 @@ pub type CUresult = ::core::result::Result<(), CUerror>;
const _: fn() = || {
let _ = std::mem::transmute::<CUresult, u32>;
};
+impl From<hip_runtime_sys::hipErrorCode_t> for CUerror {
+ fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
+ Self(error.0)
+ }
+}
diff --git a/zluda/Cargo.toml b/zluda/Cargo.toml
index 0092430..ab87b6c 100644
--- a/zluda/Cargo.toml
+++ b/zluda/Cargo.toml
@@ -11,6 +11,7 @@ crate-type = ["cdylib"]
[dependencies]
ptx = { path = "../ptx" }
cuda_types = { path = "../cuda_types" }
+cuda_base = { path = "../cuda_base" }
hip_runtime-sys = { path = "../ext/hip_runtime-sys" }
lazy_static = "1.4"
num_enum = "0.4"
diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs
index 1335ef6..03a68d8 100644
--- a/zluda/src/impl/mod.rs
+++ b/zluda/src/impl/mod.rs
@@ -1,230 +1,26 @@
-use hip_runtime_sys::hipError_t;
-
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
-use std::{
- ffi::c_void,
- mem::{self, ManuallyDrop},
- os::raw::c_int,
- ptr,
- sync::Mutex,
- sync::TryLockError,
-};
-
-#[cfg(test)]
-#[macro_use]
-pub mod test;
-pub mod device;
-pub mod export_table;
-pub mod function;
-#[cfg_attr(windows, path = "os_win.rs")]
-#[cfg_attr(not(windows), path = "os_unix.rs")]
-pub(crate) mod os;
-pub(crate) mod module;
-pub(crate) mod context;
-pub(crate) mod memory;
-pub(crate) mod link;
-pub(crate) mod pointer;
+use cuda_types::*;
+use hip_runtime_sys::*;
#[cfg(debug_assertions)]
-pub fn unimplemented() -> CUresult {
+pub(crate) fn unimplemented() -> CUresult {
unimplemented!()
}
#[cfg(not(debug_assertions))]
-pub fn unimplemented() -> CUresult {
- CUresult::CUDA_ERROR_NOT_SUPPORTED
-}
-
-#[macro_export]
-macro_rules! hip_call {
- ($expr:expr) => {
- #[allow(unused_unsafe)]
- {
- let err = unsafe { $expr };
- if err != hip_runtime_sys::hipError_t::hipSuccess {
- return Result::Err(err);
- }
- }
- };
-}
-
-pub trait HasLivenessCookie: Sized {
- const COOKIE: usize;
- const LIVENESS_FAIL: CUresult;
-
- fn try_drop(&mut self) -> Result<(), CUresult>;
-}
-
-// This struct is a best-effort check if wrapped value has been dropped,
-// while it's inherently safe, its use coming from FFI is very unsafe
-#[repr(C)]
-pub struct LiveCheck<T: HasLivenessCookie> {
- cookie: usize,
- data: ManuallyDrop<T>,
-}
-
-impl<T: HasLivenessCookie> LiveCheck<T> {
- pub fn new(data: T) -> Self {
- LiveCheck {
- cookie: T::COOKIE,
- data: ManuallyDrop::new(data),
- }
- }
-
- fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
- let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
- ctx_box.try_drop()?;
- unsafe { ManuallyDrop::drop(&mut ctx_box) };
- Ok(())
- }
-
- unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
- let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
- outer_ptr as *mut Self
- }
-
- pub unsafe fn as_ref_unchecked(&self) -> &T {
- &self.data
- }
-
- pub fn as_option_mut(&mut self) -> Option<&mut T> {
- if self.cookie == T::COOKIE {
- Some(&mut self.data)
- } else {
- None
- }
- }
-
- pub fn as_result(&self) -> Result<&T, CUresult> {
- if self.cookie == T::COOKIE {
- Ok(&self.data)
- } else {
- Err(T::LIVENESS_FAIL)
- }
- }
-
- pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
- if self.cookie == T::COOKIE {
- Ok(&mut self.data)
- } else {
- Err(T::LIVENESS_FAIL)
- }
- }
-
- #[must_use]
- pub fn try_drop(&mut self) -> Result<(), CUresult> {
- if self.cookie == T::COOKIE {
- self.cookie = 0;
- self.data.try_drop()?;
- unsafe { ManuallyDrop::drop(&mut self.data) };
- return Ok(());
- }
- Err(T::LIVENESS_FAIL)
- }
-}
-
-impl<T: HasLivenessCookie> Drop for LiveCheck<T> {
- fn drop(&mut self) {
- self.cookie = 0;
- }
-}
-
-pub trait CudaRepr: Sized {
- type Impl: Sized;
-}
-
-impl<T: CudaRepr> CudaRepr for *mut T {
- type Impl = *mut T::Impl;
+pub(crate) fn unimplemented() -> CUresult {
+ CUresult::ERROR_NOT_SUPPORTED
}
-pub trait Decuda<To> {
- fn decuda(self: Self) -> To;
+pub(crate) trait FromCuda<T>: Sized {
+ fn from_cuda(t: T) -> Result<Self, CUerror>;
}
-impl<T: CudaRepr> Decuda<*mut T::Impl> for *mut T {
- fn decuda(self: Self) -> *mut T::Impl {
- self as *mut _
+impl FromCuda<u32> for u32 {
+ fn from_cuda(x: u32) -> Result<Self, CUerror> {
+ Ok(x)
}
}
-impl<T> From<TryLockError<T>> for CUresult {
- fn from(_: TryLockError<T>) -> Self {
- CUresult::CUDA_ERROR_ILLEGAL_STATE
- }
-}
-
-impl From<ocl_core::Error> for CUresult {
- fn from(result: ocl_core::Error) -> Self {
- match result {
- _ => CUresult::CUDA_ERROR_UNKNOWN,
- }
- }
-}
-
-impl From<hip_runtime_sys::hipError_t> for CUresult {
- fn from(result: hip_runtime_sys::hipError_t) -> Self {
- match result {
- hip_runtime_sys::hipError_t::hipErrorRuntimeMemory
- | hip_runtime_sys::hipError_t::hipErrorRuntimeOther => CUresult::CUDA_ERROR_UNKNOWN,
- hip_runtime_sys::hipError_t(e) => CUresult(e),
- }
- }
-}
-
-pub trait Encuda {
- type To: Sized;
- fn encuda(self: Self) -> Self::To;
-}
-
-impl Encuda for CUresult {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- self
- }
-}
-
-impl Encuda for () {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- CUresult::CUDA_SUCCESS
- }
-}
-
-impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1, T2> {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- match self {
- Ok(e) => e.encuda(),
- Err(e) => e.encuda(),
- }
- }
-}
-
-impl Encuda for hipError_t {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- self.into()
- }
-}
-
-unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
- mem::transmute(t)
-}
-
-unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
- mem::transmute(t)
-}
-
-pub fn driver_get_version() -> c_int {
- i32::max_value()
-}
-
-impl<'a> CudaRepr for CUdeviceptr {
- type Impl = *mut c_void;
-}
-
-impl Decuda<*mut c_void> for CUdeviceptr {
- fn decuda(self) -> *mut c_void {
- self.0 as *mut _
- }
+pub(crate) fn init(flags: ::core::ffi::c_uint) -> hipError_t {
+ unsafe { hipInit(flags) }
}
diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs
index 9986051..9a6c402 100644
--- a/zluda/src/lib.rs
+++ b/zluda/src/lib.rs
@@ -1,11 +1,37 @@
-extern crate lazy_static;
-#[cfg(test)]
-extern crate cuda_driver_sys;
-#[cfg(test)]
-extern crate paste;
-extern crate ptx;
-
-#[allow(warnings)]
-pub mod cuda;
-mod cuda_impl;
pub(crate) mod r#impl;
+
+macro_rules! unimplemented {
+ ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => {
+ $(
+ #[cfg_attr(not(test), no_mangle)]
+ #[allow(improper_ctypes)]
+ #[allow(improper_ctypes_definitions)]
+ pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
+ crate::r#impl::unimplemented()
+ }
+ )*
+ };
+}
+
+macro_rules! implemented {
+ ($($abi:literal fn $fn_name:ident( $($arg_id:ident : $arg_type:ty),* ) -> $ret_type:path;)*) => {
+ $(
+ #[cfg_attr(not(test), no_mangle)]
+ #[allow(improper_ctypes)]
+ #[allow(improper_ctypes_definitions)]
+ pub unsafe extern $abi fn $fn_name ( $( $arg_id : $arg_type),* ) -> $ret_type {
+ cuda_base::cuda_normalize_fn!( crate::r#impl::$fn_name ) ($(crate::r#impl::FromCuda::from_cuda($arg_id)?),*)?;
+ Ok(())
+ }
+ )*
+ };
+}
+
+
+use cuda_base::cuda_function_declarations;
+cuda_function_declarations!(
+ unimplemented,
+ implemented <= [
+ cuInit
+ ]
+); \ No newline at end of file
diff --git a/zluda_bindgen/src/main.rs b/zluda_bindgen/src/main.rs
index ebb357c..b7c7dac 100644
--- a/zluda_bindgen/src/main.rs
+++ b/zluda_bindgen/src/main.rs
@@ -183,6 +183,7 @@ impl ConvertIntoRustResult {
#[repr(transparent)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct #new_error_type(pub ::core::num::NonZeroU32);
+
pub trait #type_trait {
#(#result_variants)*
}
@@ -192,6 +193,12 @@ impl ConvertIntoRustResult {
const _: fn() = || {
let _ = std::mem::transmute::<#type_, u32>;
};
+
+ impl From<hip_runtime_sys::hipErrorCode_t> for #new_error_type {
+ fn from(error: hip_runtime_sys::hipErrorCode_t) -> Self {
+ Self(error.0)
+ }
+ }
};
items.extend(extra_items);
}