summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--level_zero/src/ze.rs12
-rw-r--r--notcuda/build.rs27
-rw-r--r--notcuda/src/cuda.rs25
-rw-r--r--notcuda/src/impl/context.rs191
-rw-r--r--notcuda/src/impl/device.rs283
-rw-r--r--notcuda/src/impl/export_table.rs83
-rw-r--r--notcuda/src/impl/function.rs66
-rw-r--r--notcuda/src/impl/memory.rs83
-rw-r--r--notcuda/src/impl/mod.rs161
-rw-r--r--notcuda/src/impl/module.rs206
-rw-r--r--notcuda/src/impl/stream.rs203
-rw-r--r--notcuda/src/impl/test.rs57
-rw-r--r--ptx/src/lib.rs5
-rw-r--r--ptx/src/test/mod.rs2
-rw-r--r--ptx/src/translate.rs30
15 files changed, 904 insertions, 530 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index f8a2c3b..4267682 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -173,6 +173,16 @@ impl Context {
check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
Ok(Context(result))
}
+
+ pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
+ check! {
+ sys::zeMemFree(
+ self.0,
+ ptr,
+ )
+ };
+ Ok(())
+ }
}
impl Drop for Context {
@@ -239,7 +249,7 @@ pub struct Module(sys::ze_module_handle_t);
impl Module {
// HACK ALERT
- // We use OpenCL for now to do SPIR-V linking, because Level0
+ // We use OpenCL for now to do SPIR-V linking, because Level0
// does not allow linking. Don't let presence of zeModuleDynamicLink fool
// you, it's not currently possible to create non-compiled modules.
// zeModuleCreate always compiles (builds and links).
diff --git a/notcuda/build.rs b/notcuda/build.rs
new file mode 100644
index 0000000..3b8999f
--- /dev/null
+++ b/notcuda/build.rs
@@ -0,0 +1,27 @@
+// HACK ALERT
+// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries
+// overriding path to OpenCL
+
+ fn main() {
+ if cfg!(windows) {
+ let known_sdk = [
+ // E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\"
+ ("INTELOCLSDKROOT", "x64", "x86"),
+ // E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\"
+ ("CUDA_PATH", "x64", "Win32"),
+ // E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\"
+ ("AMDAPPSDKROOT", "x86_64", "x86"),
+ ];
+
+ for info in known_sdk.iter() {
+ if let Ok(sdk) = std::env::var(info.0) {
+ let mut path = std::path::PathBuf::from(sdk);
+ path.push("lib");
+ path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 });
+ println!("cargo:rustc-link-search=native={}", path.display());
+ }
+ }
+
+ println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64");
+ }
+} \ No newline at end of file
diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs
index a18ebf9..335da4a 100644
--- a/notcuda/src/cuda.rs
+++ b/notcuda/src/cuda.rs
@@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int)
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult {
- r#impl::device::get(device.decuda(), ordinal)
+ r#impl::device::get(device.decuda(), ordinal).encuda()
}
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult {
- r#impl::device::get_count(count)
+ r#impl::device::get_count(count).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult {
cuDevicePrimaryCtxReset_v2(dev)
}
-
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult {
r#impl::unimplemented()
@@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2(
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult {
- r#impl::context::destroy_v2(ctx.decuda())
+ r#impl::context::destroy_v2(ctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult {
- r#impl::context::get_device(device.decuda())
+ r#impl::context::get_device(device.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion(
ctx: CUcontext,
version: *mut ::std::os::raw::c_uint,
) -> CUresult {
- r#impl::context::get_api_version(ctx.decuda(), version)
+ r#impl::context::get_api_version(ctx.decuda(), version).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult {
- r#impl::unimplemented()
+ r#impl::context::attach(pctx.decuda(), flags).encuda()
}
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult {
- r#impl::unimplemented()
+ r#impl::context::detach(ctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData(
module: *mut CUmodule,
image: *const ::std::os::raw::c_void,
) -> CUresult {
- r#impl::unimplemented()
+ r#impl::module::load_data(module.decuda(), image).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult {
- r#impl::memory::alloc_v2(dptr.decuda(), bytesize)
+ r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate(
phStream: *mut CUstream,
Flags: ::std::os::raw::c_uint,
) -> CUresult {
- r#impl::unimplemented()
+ r#impl::stream::create(phStream.decuda(), Flags).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags(
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult {
- r#impl::unimplemented()
+ r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
@@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult {
- r#impl::unimplemented()
+ r#impl::stream::destroy_v2(hStream.decuda()).encuda()
}
#[cfg_attr(not(test), no_mangle)]
diff --git a/notcuda/src/impl/context.rs b/notcuda/src/impl/context.rs
index 91d4460..9689ecf 100644
--- a/notcuda/src/impl/context.rs
+++ b/notcuda/src/impl/context.rs
@@ -1,18 +1,15 @@
-use super::CUresult;
-use super::{device, HasLivenessCookie, LiveCheck};
+use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck};
+use super::{CUresult, GlobalState};
use crate::{cuda::CUcontext, cuda_impl};
use l0::sys::ze_result_t;
-use std::mem::{self, ManuallyDrop};
+use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32};
use std::{
- cell::RefCell,
- num::NonZeroU32,
- os::raw::c_uint,
- ptr,
- sync::{atomic::AtomicU32, Mutex},
+ collections::HashSet,
+ mem::{self},
};
thread_local! {
- pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new());
+ pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new());
}
pub type Context = LiveCheck<ContextData>;
@@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData {
#[cfg(target_pointer_width = "32")]
const COOKIE: usize = 0x0b643ffb;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ for stream in self.streams.iter() {
+ let stream = unsafe { &mut **stream };
+ stream.context = ptr::null_mut();
+ Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?;
+ }
+ Ok(())
+ }
}
enum ContextRefCount {
@@ -67,26 +75,16 @@ impl ContextRefCount {
}
}
}
-
- fn is_primary(&self) -> bool {
- match self {
- ContextRefCount::Primary => true,
- ContextRefCount::NonPrimary(_) => false,
- }
- }
}
pub struct ContextData {
pub flags: AtomicU32,
- pub device_index: device::Index,
// This pointer is null only for a moment when constructing primary context
- pub device: *const Mutex<device::Device>,
- // The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState
- pub mutable: Mutex<ContextDataMutable>,
-}
-
-pub struct ContextDataMutable {
+ pub device: *mut device::Device,
ref_count: ContextRefCount,
+ pub default_stream: StreamData,
+ pub streams: HashSet<*mut StreamData>,
+ // All the fields below are here to support internal CUDA driver API
pub cuda_manager: *mut cuda_impl::rt::ContextStateManager,
pub cuda_state: *mut cuda_impl::rt::ContextState,
pub cuda_dtor_cb: Option<
@@ -100,63 +98,75 @@ pub struct ContextDataMutable {
impl ContextData {
pub fn new(
+ l0_ctx: &mut l0::Context,
+ l0_dev: &l0::Device,
flags: c_uint,
is_primary: bool,
- dev_index: device::Index,
- dev: *const Mutex<device::Device>,
- ) -> Self {
- ContextData {
+ dev: *mut device::Device,
+ ) -> Result<Self, CUresult> {
+ let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?;
+ Ok(ContextData {
flags: AtomicU32::new(flags),
- device_index: dev_index,
device: dev,
- mutable: Mutex::new(ContextDataMutable {
- ref_count: ContextRefCount::new(is_primary),
- cuda_manager: ptr::null_mut(),
- cuda_state: ptr::null_mut(),
- cuda_dtor_cb: None,
- }),
- }
+ ref_count: ContextRefCount::new(is_primary),
+ default_stream,
+ streams: HashSet::new(),
+ cuda_manager: ptr::null_mut(),
+ cuda_state: ptr::null_mut(),
+ cuda_dtor_cb: None,
+ })
}
}
-pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult {
+impl Context {
+ pub fn late_init(&mut self) {
+ let ctx_data = self.as_option_mut().unwrap();
+ ctx_data.default_stream.context = ctx_data as *mut _;
+ }
+}
+
+pub fn create_v2(
+ pctx: *mut *mut Context,
+ flags: u32,
+ dev_idx: device::Index,
+) -> Result<(), CUresult> {
if pctx == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let dev = device::get_device_ref(dev_idx);
- let dev = match dev {
- Ok(d) => d,
- Err(e) => return e,
- };
- let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev)));
- let ctx_ref = ctx.as_mut() as *mut Context;
+ let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| {
+ let dev_ptr = dev as *mut _;
+ let mut ctx_box = Box::new(LiveCheck::new(ContextData::new(
+ &mut dev.l0_context,
+ &dev.base,
+ flags,
+ false,
+ dev_ptr as *mut _,
+ )?));
+ ctx_box.late_init();
+ Ok::<_, CUresult>(ctx_box)
+ })??;
+ let ctx_ref = ctx_box.as_mut() as *mut Context;
unsafe { *pctx = ctx_ref };
- mem::forget(ctx);
+ mem::forget(ctx_box);
CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref));
- CUresult::CUDA_SUCCESS
+ Ok(())
}
-pub fn destroy_v2(ctx: *mut Context) -> CUresult {
+pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
CONTEXT_STACK.with(|stack| {
let mut stack = stack.borrow_mut();
let should_pop = match stack.last() {
- Some(active_ctx) => *active_ctx == (ctx as *const _),
+ Some(active_ctx) => *active_ctx == (ctx as *mut _),
None => false,
};
if should_pop {
stack.pop();
}
});
- let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) });
- if !ctx_box.try_drop() {
- CUresult::CUDA_ERROR_INVALID_CONTEXT
- } else {
- unsafe { ManuallyDrop::drop(&mut ctx_box) };
- CUresult::CUDA_SUCCESS
- }
+ GlobalState::lock(|_| Context::destroy_impl(ctx))?
}
pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
@@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult {
CUresult::CUDA_SUCCESS
}
-pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> {
- CONTEXT_STACK.with(|stack| {
- stack
- .borrow()
- .last()
- .and_then(|c| unsafe { &**c }.as_ref())
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .map(f)
- })
-}
-
pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> {
if pctx == ptr::null_mut() {
return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT);
@@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult {
}
}
-pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult {
- let _ctx = match unsafe { ctx.as_mut() } {
- None => return CUresult::CUDA_ERROR_INVALID_VALUE,
- Some(ctx) => match ctx.as_mut() {
- None => return CUresult::CUDA_ERROR_INVALID_CONTEXT,
- Some(ctx) => ctx,
- },
- };
+pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
+ if ctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock(|_| {
+ unsafe { &*ctx }.as_result()?;
+ Ok::<_, CUresult>(())
+ })??;
//TODO: query device for properties roughly matching CUDA API version
unsafe { *version = 1100 };
- CUresult::CUDA_SUCCESS
+ Ok(())
}
-pub fn get_device(dev: *mut device::Index) -> CUresult {
- let dev_idx = with_current(|ctx| ctx.device_index);
- match dev_idx {
- Ok(idx) => {
- unsafe { *dev = idx }
- CUresult::CUDA_SUCCESS
- }
- Err(err) => err,
+pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> {
+ let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?;
+ unsafe { *dev = dev_idx };
+ Ok(())
+}
+
+pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> {
+ if pctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
+ let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
+ let ctx = unchecked_ctx.as_result_mut()?;
+ ctx.ref_count.incr()?;
+ Ok::<_, CUresult>(unchecked_ctx as *mut _)
+ })??;
+ unsafe { *pctx = ctx };
+ Ok(())
}
-#[cfg(test)]
-pub fn is_context_stack_empty() -> bool {
- CONTEXT_STACK.with(|stack| stack.borrow().is_empty())
+pub fn detach(pctx: *mut Context) -> Result<(), CUresult> {
+ if pctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock_current_context_unchecked(|unchecked_ctx| {
+ let ctx = unchecked_ctx.as_result_mut()?;
+ if ctx.ref_count.decr() {
+ Context::destroy_impl(unchecked_ctx)?;
+ }
+ Ok::<_, CUresult>(())
+ })?
}
#[cfg(test)]
-mod tests {
+mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::{ffi::c_void, ptr};
diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs
index d4859d3..b8d263d 100644
--- a/notcuda/src/impl/device.rs
+++ b/notcuda/src/impl/device.rs
@@ -1,24 +1,21 @@
-use super::{context, transmute_lifetime, CUresult, Error};
+use super::{context, CUresult, GlobalState};
use crate::cuda;
use cuda::{CUdevice_attribute, CUuuid_st};
use std::{
cmp, mem,
os::raw::{c_char, c_int},
ptr,
- sync::{
- atomic::{AtomicU32, Ordering},
- Mutex, MutexGuard,
- },
+ sync::atomic::{AtomicU32, Ordering},
};
const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]";
-static mut DEVICES: Option<Vec<Mutex<Device>>> = None;
#[repr(transparent)]
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct Index(pub c_int);
pub struct Device {
+ pub index: Index,
pub base: l0::Device,
pub default_queue: l0::CommandQueue,
pub l0_context: l0::Context,
@@ -33,17 +30,19 @@ unsafe impl Send for Device {}
impl Device {
// Unsafe because it does not fully initalize primary_context
- unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result<Self> {
+ unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> {
let mut ctx = l0::Context::new(drv)?;
- let queue = l0::CommandQueue::new(&mut ctx, &d)?;
+ let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?;
let primary_context = context::Context::new(context::ContextData::new(
+ &mut ctx,
+ &l0_dev,
0,
true,
- Index(idx as c_int),
- ptr::null(),
- ));
+ ptr::null_mut(),
+ )?);
Ok(Self {
- base: d,
+ index: Index(idx as c_int),
+ base: l0_dev,
default_queue: queue,
l0_context: ctx,
primary_context: primary_context,
@@ -93,83 +92,53 @@ impl Device {
Err(e) => Err(e),
}
}
+
+ pub fn late_init(&mut self) {
+ self.primary_context.as_option_mut().unwrap().device = self as *mut _;
+ }
}
-pub fn init(driver: &l0::Driver) -> l0::Result<()> {
+pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> {
let ze_devices = driver.devices()?;
let mut devices = ze_devices
.into_iter()
.enumerate()
- .map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new))
+ .map(|(idx, d)| unsafe { Device::new(driver, d, idx) })
.collect::<Result<Vec<_>, _>>()?;
- for d in devices.iter_mut() {
- d.get_mut()
- .unwrap()
- .primary_context
- .as_mut()
- .unwrap()
- .device = d;
- }
- unsafe { DEVICES = Some(devices) };
- Ok(())
-}
-
-fn devices() -> Result<&'static Vec<Mutex<Device>>, CUresult> {
- match unsafe { &DEVICES } {
- Some(devs) => Ok(devs),
- None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED),
- }
-}
-
-pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex<Device>, CUresult> {
- let devs = devices()?;
- if dev_idx < 0 || dev_idx >= devs.len() as c_int {
- return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
+ for dev in devices.iter_mut() {
+ dev.late_init();
+ dev.primary_context.late_init();
}
- Ok(&devs[dev_idx as usize])
+ Ok(devices)
}
-pub fn get_device(dev_idx: Index) -> Result<MutexGuard<'static, Device>, CUresult> {
- let dev = get_device_ref(dev_idx)?;
- dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
-}
-
-pub fn get_count(count: *mut c_int) -> CUresult {
- let len = devices().map(|d| d.len());
- match len {
- Ok(len) => {
- unsafe { *count = len as c_int };
- CUresult::CUDA_SUCCESS
- }
- Err(e) => e,
- }
+pub fn get_count(count: *mut c_int) -> Result<(), CUresult> {
+ let len = GlobalState::lock(|state| state.devices.len())?;
+ unsafe { *count = len as c_int };
+ Ok(())
}
-pub fn get(device: *mut Index, ordinal: c_int) -> CUresult {
+pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> {
if device == ptr::null_mut() || ordinal < 0 {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let len = devices().map(|d| d.len());
- match len {
- Ok(len) if ordinal < (len as i32) => {
- unsafe { *device = Index(ordinal) };
- CUresult::CUDA_SUCCESS
- }
- Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE,
- Err(e) => e,
+ let len = GlobalState::lock(|state| state.devices.len())?;
+ if ordinal < (len as i32) {
+ unsafe { *device = Index(ordinal) };
+ Ok(())
+ } else {
+ Err(CUresult::CUDA_ERROR_INVALID_VALUE)
}
}
-pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> {
+pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> {
if name == ptr::null_mut() || len < 0 {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- // This is safe because devices are 'static
- let name_ptr = {
- let mut dev = get_device(dev)?;
- let props = dev.get_properties().map_err(Into::<CUresult>::into)?;
- props.name.as_ptr()
- };
+ let name_ptr = GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr())
+ })??;
let name_len = (0..256)
.position(|i| unsafe { *name_ptr.add(i) } == 0)
.unwrap_or(256);
@@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult>
Ok(())
}
-pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> {
+pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> {
if bytes == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- // This is safe because devices are 'static
- let mem_props = {
- let mut dev = get_device(dev)?;
- unsafe {
- transmute_lifetime(
- dev.get_memory_properties()
- .map_err(Into::<CUresult>::into)?,
- )
- }
- };
+ let mem_props = GlobalState::lock_device(dev_idx, |dev| {
+ let mem_props = dev.get_memory_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(mem_props)
+ })??;
let max_mem = mem_props
.iter()
.map(|p| p.totalSize)
@@ -228,56 +191,101 @@ impl CUdevice_attribute {
}
}
-pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> {
+pub fn get_attribute(
+ pi: *mut i32,
+ attrib: CUdevice_attribute,
+ dev_idx: Index,
+) -> Result<(), CUresult> {
if pi == ptr::null_mut() {
- return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE));
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
if let Some(value) = attrib.get_static_value() {
unsafe { *pi = value };
return Ok(());
}
- let mut dev = get_device(dev).map_err(Error::Cuda)?;
let value = match attrib {
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => {
- dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32)
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => {
- let props = dev.get_properties().map_err(Error::L0)?;
- (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(
+ (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32,
+ )
+ })??
+ }
+ CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => {
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_image_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(cmp::min(
+ props.maxImageDims1D,
+ c_int::max_value() as u32,
+ ) as c_int)
+ })??
}
- CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min(
- dev.get_image_properties()
- .map_err(Error::L0)?
- .maxImageDims1D,
- c_int::max_value() as u32,
- ) as c_int,
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(cmp::max(
+ i32::max_value() as u32,
+ props.maxGroupCountX,
+ ) as i32)
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(cmp::max(
+ i32::max_value() as u32,
+ props.maxGroupCountY,
+ ) as i32)
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(cmp::max(
+ i32::max_value() as u32,
+ props.maxGroupCountZ,
+ ) as i32)
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32,
+ )
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32,
+ )
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(
+ cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32,
+ )
+ })??
}
CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
- let props = dev.get_compute_properties().map_err(Error::L0)?;
- cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32
+ GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_compute_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(cmp::max(
+ i32::max_value() as u32,
+ props.maxTotalGroupSize,
+ ) as i32)
+ })??
}
_ => {
// TODO: support more attributes for CUDA runtime
@@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re
Ok(())
}
-pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
- let ze_uuid = {
- get_device(dev)
- .map_err(Error::Cuda)?
- .get_properties()
- .map_err(Error::L0)?
- .uuid
- };
+pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> {
+ let ze_uuid = GlobalState::lock_device(dev_idx, |dev| {
+ let props = dev.get_properties()?;
+ Ok::<_, l0::sys::ze_result_t>(props.uuid)
+ })??;
unsafe {
*uuid = CUuuid_st {
bytes: mem::transmute(ze_uuid.id),
@@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> {
Ok(())
}
-pub fn with_current_exclusive<F: FnOnce(&mut Device) -> R, R>(f: F) -> Result<R, CUresult> {
- let dev = super::context::with_current(|ctx| ctx.device);
- dev.and_then(|dev| {
- unsafe { &*dev }
- .try_lock()
- .map(|mut dev| f(&mut dev))
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- })
-}
-
-pub fn with_exclusive<F: FnOnce(&mut Device) -> R, R>(dev: Index, f: F) -> Result<R, CUresult> {
- let dev = get_device_ref(dev)?;
- dev.try_lock()
- .map(|mut dev| f(&mut dev))
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
-}
-
pub fn primary_ctx_get_state(
- idx: Index,
+ dev_idx: Index,
flags: *mut u32,
active: *mut i32,
) -> Result<(), CUresult> {
- let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| {
+ let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| {
// This is safe because primary context can't be dropped
- let ctx_ptr = &dev.primary_context as *const _;
+ let ctx_ptr = &mut dev.primary_context as *mut _;
let flags_ptr =
(&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32;
- (ctx_ptr, flags_ptr)
- })?;
- let is_active = context::CONTEXT_STACK
- .with(|stack| stack.borrow().last().map(|x| *x))
- .map(|current| current == ctx_ptr)
- .unwrap_or(false);
- let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
- unsafe { *flags = flags_value };
+ let is_active = context::CONTEXT_STACK
+ .with(|stack| stack.borrow().last().map(|x| *x))
+ .map(|current| current == ctx_ptr)
+ .unwrap_or(false);
+ let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed);
+ Ok::<_, l0::sys::ze_result_t>((is_active, flags_value))
+ })??;
unsafe { *active = if is_active { 1 } else { 0 } };
+ unsafe { *flags = flags_value };
Ok(())
}
-pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> {
- let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?;
+pub fn primary_ctx_retain(
+ pctx: *mut *mut context::Context,
+ dev_idx: Index,
+) -> Result<(), CUresult> {
+ let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?;
unsafe { *pctx = ctx_ptr };
Ok(())
}
#[cfg(test)]
-mod tests {
+mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs
index 562af37..ae9f6e3 100644
--- a/notcuda/src/impl/export_table.rs
+++ b/notcuda/src/impl/export_table.rs
@@ -4,7 +4,7 @@ use crate::{
cuda_impl,
};
-use super::{context, device, module, Decuda, Encuda};
+use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
- cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
-}
-
-fn cudart_interface_fn1_impl(
- pctx: *mut *mut context::Context,
- dev: device::Index,
-) -> Result<(), CUresult> {
- let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
- unsafe { *pctx = ctx_ptr };
- Ok(())
+unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
+ super::unimplemented()
}
/*
@@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin(
ptr1: *mut c_void,
ptr2: *mut c_void,
) -> CUresult {
- // Not sure what those twoparameters are actually used for,
+ // Not sure what those two parameters are actually used for,
// they are somehow involved in __cudaRegisterHostVar
if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
return CUresult::CUDA_ERROR_NOT_SUPPORTED;
@@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin(
},
Err(_) => continue,
};
- let module = module::ModuleData::compile_spirv(kernel_text_string);
+ let module = module::SpirvModule::new(kernel_text_string);
match module {
Ok(module) => {
- *result = Box::into_raw(Box::new(module));
+ match module::load_data_impl(result, module) {
+ Ok(()) => {}
+ Err(err) => return err,
+ }
return CUresult::CUDA_SUCCESS;
}
Err(_) => continue,
@@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor(
}
fn context_local_storage_ctor_impl(
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option<
@@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl(
),
>,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mut mutable| {
- mutable.cuda_manager = mgr;
- mutable.cuda_state = ctx_state;
- mutable.cuda_dtor_cb = dtor_cb;
- })
- })?;
- Ok(())
+ lock_context(cu_ctx, |ctx: &mut ContextData| {
+ ctx.cuda_manager = mgr;
+ ctx.cuda_state = ctx_state;
+ ctx.cuda_dtor_cb = dtor_cb;
+ })
}
// some kind of dtor
@@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state(
fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState,
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- let cuda_state = unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mutable| mutable.cuda_state)
- })?;
+ let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
if cuda_state == ptr::null_mut() {
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else {
@@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl(
Ok(())
}
}
+
+fn lock_context<T>(
+ cu_ctx: *mut context::Context,
+ fn_impl: impl FnOnce(&mut ContextData) -> T,
+) -> Result<T, CUresult> {
+ if cu_ctx == ptr::null_mut() {
+ GlobalState::lock_current_context(fn_impl)
+ } else {
+ GlobalState::lock(|_| {
+ let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
+ Ok(fn_impl(ctx))
+ })?
+ }
+}
diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs
index 0ab3bea..394f806 100644
--- a/notcuda/src/impl/function.rs
+++ b/notcuda/src/impl/function.rs
@@ -1,11 +1,28 @@
use ::std::os::raw::{c_uint, c_void};
use std::ptr;
-use super::{device, stream::Stream, CUresult};
+use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream};
-pub struct Function {
+pub type Function = LiveCheck<FunctionData>;
+
+impl HasLivenessCookie for FunctionData {
+ #[cfg(target_pointer_width = "64")]
+ const COOKIE: usize = 0x5e2ab14d5840678e;
+
+ #[cfg(target_pointer_width = "32")]
+ const COOKIE: usize = 0x33e6a1e6;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ Ok(())
+ }
+}
+
+pub struct FunctionData {
pub base: l0::Kernel<'static>,
pub arg_size: Vec<usize>,
+ pub use_shared_mem: bool,
}
pub fn launch_kernel(
@@ -17,36 +34,43 @@ pub fn launch_kernel(
block_dim_y: c_uint,
block_dim_z: c_uint,
shared_mem_bytes: c_uint,
- strean: *mut Stream,
+ hstream: *mut Stream,
kernel_params: *mut *mut c_void,
extra: *mut *mut c_void,
) -> Result<(), CUresult> {
if f == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() {
+ if extra != ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED);
}
- let func = unsafe { &*f };
- for (i, arg_size) in func.arg_size.iter().copied().enumerate() {
- unsafe {
- func.base
- .set_arg_raw(i as u32, arg_size, *kernel_params.add(i))?
- };
- }
- unsafe { &*f }
- .base
- .set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
- device::with_current_exclusive(|dev| {
- let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?;
+ GlobalState::lock_stream(hstream, |stream| {
+ let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
+ for (i, arg_size) in func.arg_size.iter().enumerate() {
+ unsafe {
+ func.base
+ .set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))?
+ };
+ }
+ if func.use_shared_mem {
+ unsafe {
+ func.base.set_arg_raw(
+ func.arg_size.len() as u32,
+ shared_mem_bytes as usize,
+ ptr::null(),
+ )?
+ };
+ }
+ func.base
+ .set_group_size(block_dim_x, block_dim_y, block_dim_z)?;
+ let mut cmd_list = stream.command_list()?;
cmd_list.append_launch_kernel(
- &unsafe { &*f }.base,
+ &mut func.base,
&[grid_dim_x, grid_dim_y, grid_dim_z],
None,
&mut [],
)?;
- dev.default_queue.execute(cmd_list)?;
- l0::Result::Ok(())
- })??;
- Ok(())
+ stream.queue.execute(cmd_list)?;
+ Ok(())
+ })?
}
diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs
index 439b26f..62dc1cc 100644
--- a/notcuda/src/impl/memory.rs
+++ b/notcuda/src/impl/memory.rs
@@ -1,57 +1,34 @@
-use super::CUresult;
+use super::{stream, CUresult, GlobalState};
use std::ffi::c_void;
-pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
- let alloc_result = super::device::with_current_exclusive(|dev| unsafe {
- dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0)
- });
- match alloc_result {
- Ok(Ok(alloc)) => {
- unsafe { *dptr = alloc };
- CUresult::CUDA_SUCCESS
- }
- Ok(Err(e)) => e.into(),
- Err(e) => e,
- }
+pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
+ let ptr = GlobalState::lock_current_context(|ctx| {
+ let dev = unsafe { &mut *ctx.device };
+ Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
+ })??;
+ unsafe { *dptr = ptr };
+ Ok(())
}
-pub fn copy_v2(
- dst: *mut c_void,
- src: *const c_void,
- bytesize: usize,
-) -> Result<Result<(), l0::sys::ze_result_t>, CUresult> {
- super::device::with_current_exclusive(|dev| unsafe {
- memcpy_impl(
- &mut dev.l0_context,
- dst,
- src,
- bytesize,
- &dev.base,
- &mut dev.default_queue,
- )
- })
+pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
+ GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
+ let mut cmd_list = stream.command_list()?;
+ unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
+ stream.queue.execute(cmd_list)?;
+ Ok::<_, CUresult>(())
+ })?
}
-unsafe fn memcpy_impl(
- ctx: &mut l0::Context,
- dst: *mut c_void,
- src: *const c_void,
- bytes_count: usize,
- dev: &l0::Device,
- queue: &mut l0::CommandQueue,
-) -> l0::Result<()> {
- let mut cmd_list = l0::CommandList::new(ctx, &dev)?;
- cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?;
- queue.execute(cmd_list)?;
- Ok(())
-}
-
-pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
- Ok(())
+pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
+ GlobalState::lock_current_context(|ctx| {
+ let dev = unsafe { &mut *ctx.device };
+ Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
+ })
+ .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
#[cfg(test)]
-mod tests {
+mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::ptr;
@@ -82,4 +59,20 @@ mod tests {
assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
+
+ cuda_driver_test!(free_without_ctx);
+
+ fn free_without_ctx<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ let mut mem = ptr::null_mut();
+ assert_eq!(
+ T::cuMemAlloc_v2(&mut mem, std::mem::size_of::<usize>()),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_ne!(mem, ptr::null_mut());
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
+ assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
}
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs
index 5a72ce4..770a32b 100644
--- a/notcuda/src/impl/mod.rs
+++ b/notcuda/src/impl/mod.rs
@@ -1,5 +1,15 @@
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
-use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
+use crate::{
+ cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st},
+ r#impl::device::Device,
+};
+use std::{
+ ffi::c_void,
+ mem::{self, ManuallyDrop},
+ os::raw::c_int,
+ ptr,
+ sync::Mutex,
+ sync::TryLockError,
+};
#[cfg(test)]
#[macro_use]
@@ -7,9 +17,9 @@ pub mod test;
pub mod context;
pub mod device;
pub mod export_table;
+pub mod function;
pub mod memory;
pub mod module;
-pub mod function;
pub mod stream;
#[cfg(debug_assertions)]
@@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult {
CUresult::CUDA_ERROR_NOT_SUPPORTED
}
-pub trait HasLivenessCookie {
+pub trait HasLivenessCookie: Sized {
const COOKIE: usize;
+ const LIVENESS_FAIL: CUresult;
+
+ fn try_drop(&mut self) -> Result<(), CUresult>;
}
// This struct is a best-effort check if wrapped value has been dropped,
@@ -42,34 +55,55 @@ impl<T: HasLivenessCookie> LiveCheck<T> {
}
}
+ fn destroy_impl(this: *mut Self) -> Result<(), CUresult> {
+ let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) });
+ ctx_box.try_drop()?;
+ unsafe { ManuallyDrop::drop(&mut ctx_box) };
+ Ok(())
+ }
+
+ unsafe fn ptr_from_inner(this: *mut T) -> *mut Self {
+ let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>());
+ outer_ptr as *mut Self
+ }
+
pub unsafe fn as_ref_unchecked(&self) -> &T {
&self.data
}
- pub fn as_ref(&self) -> Option<&T> {
+ pub fn as_option_mut(&mut self) -> Option<&mut T> {
if self.cookie == T::COOKIE {
- Some(&self.data)
+ Some(&mut self.data)
} else {
None
}
}
- pub fn as_mut(&mut self) -> Option<&mut T> {
+ pub fn as_result(&self) -> Result<&T, CUresult> {
if self.cookie == T::COOKIE {
- Some(&mut self.data)
+ Ok(&self.data)
} else {
- None
+ Err(T::LIVENESS_FAIL)
+ }
+ }
+
+ pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> {
+ if self.cookie == T::COOKIE {
+ Ok(&mut self.data)
+ } else {
+ Err(T::LIVENESS_FAIL)
}
}
#[must_use]
- pub fn try_drop(&mut self) -> bool {
+ pub fn try_drop(&mut self) -> Result<(), CUresult> {
if self.cookie == T::COOKIE {
self.cookie = 0;
+ self.data.try_drop()?;
unsafe { ManuallyDrop::drop(&mut self.data) };
- return true;
+ return Ok(());
}
- false
+ Err(T::LIVENESS_FAIL)
}
}
@@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult {
}
}
+impl<T> From<TryLockError<T>> for CUresult {
+ fn from(_: TryLockError<T>) -> Self {
+ CUresult::CUDA_ERROR_ILLEGAL_STATE
+ }
+}
+
pub trait Encuda {
type To: Sized;
fn encuda(self: Self) -> Self::To;
@@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1,
}
}
-pub enum Error {
- L0(l0::sys::ze_result_t),
- Cuda(CUresult),
-}
-
-impl Encuda for Error {
- type To = CUresult;
- fn encuda(self: Self) -> Self::To {
- match self {
- Error::L0(e) => e.into(),
- Error::Cuda(e) => e,
- }
- }
-}
-
lazy_static! {
static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None);
}
struct GlobalState {
- driver: l0::Driver,
+ devices: Vec<Device>,
}
unsafe impl Send for GlobalState {}
+impl GlobalState {
+ fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> {
+ let mut mutex = GLOBAL_STATE
+ .lock()
+ .unwrap_or_else(|poison| poison.into_inner());
+ let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?;
+ Ok(f(global_state))
+ }
+
+ fn lock_device<T>(
+ device::Index(dev_idx): device::Index,
+ f: impl FnOnce(&'static mut device::Device) -> T,
+ ) -> Result<T, CUresult> {
+ if dev_idx < 0 {
+ return Err(CUresult::CUDA_ERROR_INVALID_DEVICE);
+ }
+ Self::lock(|global_state| {
+ if dev_idx >= global_state.devices.len() as c_int {
+ Err(CUresult::CUDA_ERROR_INVALID_DEVICE)
+ } else {
+ Ok(f(unsafe {
+ transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize])
+ }))
+ }
+ })?
+ }
+
+ fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>(
+ f: F,
+ ) -> Result<R, CUresult> {
+ Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))?
+ }
+
+ fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>(
+ f: F,
+ ) -> Result<R, CUresult> {
+ context::CONTEXT_STACK.with(|stack| {
+ stack
+ .borrow_mut()
+ .last_mut()
+ .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
+ .map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))?
+ })
+ }
+
+ fn lock_stream<T>(
+ stream: *mut stream::Stream,
+ f: impl FnOnce(&mut stream::StreamData) -> T,
+ ) -> Result<T, CUresult> {
+ if stream == ptr::null_mut()
+ || stream == stream::CU_STREAM_LEGACY
+ || stream == stream::CU_STREAM_PER_THREAD
+ {
+ Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))?
+ } else {
+ Self::lock(|_| {
+ let stream = unsafe { &mut *stream }.as_result_mut()?;
+ Ok(f(stream))
+ })?
+ }
+ }
+}
+
// TODO: implement
fn is_intel_gpu_driver(_: &l0::Driver) -> bool {
true
}
-pub fn init() -> l0::Result<()> {
+pub fn init() -> Result<(), CUresult> {
let mut global_state = GLOBAL_STATE
.lock()
- .map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?;
+ .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?;
if global_state.is_some() {
return Ok(());
}
l0::init()?;
let drivers = l0::Driver::get()?;
- let driver = match drivers.into_iter().find(is_intel_gpu_driver) {
- None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN),
- Some(driver) => {
- device::init(&driver)?;
- driver
- }
+ let devices = match drivers.into_iter().find(is_intel_gpu_driver) {
+ None => return Err(CUresult::CUDA_ERROR_UNKNOWN),
+ Some(driver) => device::init(&driver)?,
};
- *global_state = Some(GlobalState { driver });
+ *global_state = Some(GlobalState { devices });
drop(global_state);
Ok(())
}
-unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T {
+unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T {
mem::transmute(t)
}
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 35436c3..4422107 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -1,79 +1,90 @@
use std::{
- collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
+ collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem,
+ os::raw::c_char, ptr, slice,
};
-use super::{function::Function, transmute_lifetime, CUresult};
+use super::{
+ device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie,
+ LiveCheck,
+};
use ptx;
-pub type Module = Mutex<ModuleData>;
+pub type Module = LiveCheck<ModuleData>;
+
+impl HasLivenessCookie for ModuleData {
+ #[cfg(target_pointer_width = "64")]
+ const COOKIE: usize = 0xf1313bd46505f98a;
+
+ #[cfg(target_pointer_width = "32")]
+ const COOKIE: usize = 0xbdbe3f15;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ Ok(())
+ }
+}
pub struct ModuleData {
- base: l0::Module,
- arg_lens: HashMap<CString, Vec<usize>>,
+ pub spirv: SpirvModule,
+ // This should be a Vec<>, but I'm feeling lazy
+ pub device_binaries: HashMap<device::Index, CompiledModule>,
}
-pub enum ModuleCompileError<'a> {
- Parse(
- Vec<ptx::ast::PtxError>,
- Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>,
- ),
- Compile(ptx::TranslateError),
- L0(l0::sys::ze_result_t),
- CUDA(CUresult),
+pub struct SpirvModule {
+ pub binaries: Vec<u32>,
+ pub kernel_info: HashMap<String, ptx::KernelInfo>,
+ pub should_link_ptx_impl: Option<&'static [u8]>,
+ pub build_options: CString,
}
-impl<'a> ModuleCompileError<'a> {
- pub fn get_build_log(&self) {}
+pub struct CompiledModule {
+ pub base: l0::Module,
+ pub kernels: HashMap<CString, Box<Function>>,
}
-impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> {
- fn from(err: ptx::TranslateError) -> Self {
- ModuleCompileError::Compile(err)
+impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult {
+ fn from(_: ptx::ParseError<L, T, E>) -> Self {
+ CUresult::CUDA_ERROR_INVALID_PTX
}
}
-impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> {
- fn from(err: l0::sys::ze_result_t) -> Self {
- ModuleCompileError::L0(err)
+impl From<ptx::TranslateError> for CUresult {
+ fn from(_: ptx::TranslateError) -> Self {
+ CUresult::CUDA_ERROR_INVALID_PTX
}
}
-impl<'a> From<CUresult> for ModuleCompileError<'a> {
- fn from(err: CUresult) -> Self {
- ModuleCompileError::CUDA(err)
+impl SpirvModule {
+ pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> {
+ let u8_text = unsafe { CStr::from_ptr(text) };
+ let ptx_text = u8_text
+ .to_str()
+ .map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?;
+ Self::new(ptx_text)
}
-}
-impl ModuleData {
- pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> {
+ pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> {
let mut errors = Vec::new();
- let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text);
- let ast = match ast {
- Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))),
- Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
- Ok(ast) => ast,
- };
- let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
+ let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?;
+ let spirv_module = ptx::to_spirv_module(ast)?;
+ Ok(SpirvModule {
+ binaries: spirv_module.assemble(),
+ kernel_info: spirv_module.kernel_info,
+ should_link_ptx_impl: spirv_module.should_link_ptx_impl,
+ build_options: spirv_module.build_options,
+ })
+ }
+
+ pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> {
let byte_il = unsafe {
- slice::from_raw_parts::<u8>(
- spirv.as_ptr() as *const _,
- spirv.len() * mem::size_of::<u32>(),
+ slice::from_raw_parts(
+ self.binaries.as_ptr() as *const u8,
+ self.binaries.len() * mem::size_of::<u32>(),
)
};
- let module = super::device::with_current_exclusive(|dev| {
- l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
- });
- match module {
- Ok((Ok(module), _)) => Ok(Mutex::new(Self {
- base: module,
- arg_lens: all_arg_lens
- .into_iter()
- .map(|(k, v)| (CString::new(k).unwrap(), v))
- .collect(),
- })),
- Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
- Err(err) => Err(ModuleCompileError::from(err)),
- }
+ let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?;
+ Ok(l0_module)
}
}
@@ -85,34 +96,75 @@ pub fn get_function(
if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() {
return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
}
- let name = unsafe { CStr::from_ptr(name) };
- let (mut kernel, args_len) = unsafe { &*hmod }
- .try_lock()
- .map(|module| {
- Result::<_, CUresult>::Ok((
- l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?,
- module
- .arg_lens
- .get(name)
- .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?
- .clone(),
- ))
- })
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??;
- kernel.set_indirect_access(
- l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
- | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
- | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED,
- )?;
- unsafe {
- *hfunc = Box::into_raw(Box::new(Function {
- base: kernel,
- arg_size: args_len,
- }))
- };
+ let name = unsafe { CStr::from_ptr(name) }.to_owned();
+ let function: *mut Function = GlobalState::lock_current_context(|ctx| {
+ let module = unsafe { &mut *hmod }.as_result_mut()?;
+ let device = unsafe { &mut *ctx.device };
+ let compiled_module = match module.device_binaries.entry(device.index) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let new_module = CompiledModule {
+ base: module.spirv.compile(&mut device.l0_context, &device.base)?,
+ kernels: HashMap::new(),
+ };
+ entry.insert(new_module)
+ }
+ };
+ //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
+ let kernel = match compiled_module.kernels.entry(name) {
+ hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
+ hash_map::Entry::Vacant(entry) => {
+ let kernel_info = module
+ .spirv
+ .kernel_info
+ .get(unsafe {
+ std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
+ })
+ .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
+ let kernel =
+ l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
+ entry.insert(Box::new(Function::new(FunctionData {
+ base: kernel,
+ arg_size: kernel_info.arguments_sizes.clone(),
+ use_shared_mem: kernel_info.uses_shared_mem,
+ })))
+ }
+ };
+ Ok::<_, CUresult>(kernel as *mut _)
+ })??;
+ unsafe { *hfunc = function };
Ok(())
}
-pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
+pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> {
+ let spirv_data = SpirvModule::new_raw(image as *const _)?;
+ load_data_impl(pmod, spirv_data)
+}
+
+pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> {
+ let module = GlobalState::lock_current_context(|ctx| {
+ let device = unsafe { &mut *ctx.device };
+ let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?;
+ let mut device_binaries = HashMap::new();
+ let compiled_module = CompiledModule {
+ base: l0_module,
+ kernels: HashMap::new(),
+ };
+ device_binaries.insert(device.index, compiled_module);
+ let module_data = ModuleData {
+ spirv: spirv_data,
+ device_binaries,
+ };
+ Ok::<_, CUresult>(module_data)
+ })??;
+ let module_ptr = Box::into_raw(Box::new(Module::new(module)));
+ unsafe { *pmod = module_ptr };
Ok(())
}
+
+pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> {
+ if module == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock(|_| Module::destroy_impl(module))?
+}
diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs
index 1844677..e212dfc 100644
--- a/notcuda/src/impl/stream.rs
+++ b/notcuda/src/impl/stream.rs
@@ -1,36 +1,114 @@
-use std::cell::RefCell;
+use super::{
+ context::{Context, ContextData},
+ CUresult, GlobalState,
+};
+use std::{mem, ptr};
-use device::Device;
+use super::{HasLivenessCookie, LiveCheck};
-use super::device;
+pub type Stream = LiveCheck<StreamData>;
-pub struct Stream {
- dev: *mut Device,
+pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _;
+pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _;
+
+impl HasLivenessCookie for StreamData {
+ #[cfg(target_pointer_width = "64")]
+ const COOKIE: usize = 0x512097354de18d35;
+
+ #[cfg(target_pointer_width = "32")]
+ const COOKIE: usize = 0x77d5cc0b;
+
+ const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE;
+
+ fn try_drop(&mut self) -> Result<(), CUresult> {
+ if self.context != ptr::null_mut() {
+ let context = unsafe { &mut *self.context };
+ if !context.streams.remove(&(self as *mut _)) {
+ return Err(CUresult::CUDA_ERROR_UNKNOWN);
+ }
+ }
+ Ok(())
+ }
+}
+
+pub struct StreamData {
+ pub context: *mut ContextData,
+ pub queue: l0::CommandQueue,
}
-pub struct DefaultStream {
- streams: Vec<Option<Stream>>,
+impl StreamData {
+ pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> {
+ Ok(StreamData {
+ context: ptr::null_mut(),
+ queue: l0::CommandQueue::new(ctx, dev)?,
+ })
+ }
+ pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> {
+ let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context;
+ let l0_dev = &unsafe { &*ctx.device }.base;
+ Ok(StreamData {
+ context: ctx as *mut _,
+ queue: l0::CommandQueue::new(l0_ctx, l0_dev)?,
+ })
+ }
+
+ pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> {
+ let ctx = unsafe { &mut *self.context };
+ let dev = unsafe { &mut *ctx.device };
+ l0::CommandList::new(&mut dev.l0_context, &dev.base)
+ }
}
-impl DefaultStream {
- fn new() -> Self {
- DefaultStream {
- streams: Vec::new(),
+impl Drop for StreamData {
+ fn drop(&mut self) {
+ if self.context == ptr::null_mut() {
+ return;
}
+ unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) };
}
}
-thread_local! {
- pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new());
+pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> {
+ if pctx == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?;
+ if ctx_ptr == ptr::null_mut() {
+ return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED);
+ }
+ unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) };
+ Ok(())
+}
+
+pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> {
+ let stream_ptr = GlobalState::lock_current_context(|ctx| {
+ let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?));
+ let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _;
+ if !ctx.streams.insert(stream_ptr) {
+ return Err(CUresult::CUDA_ERROR_UNKNOWN);
+ }
+ mem::forget(stream_box);
+ Ok::<_, CUresult>(stream_ptr)
+ })??;
+ unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) };
+ Ok(())
+}
+
+pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> {
+ if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD
+ {
+ return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
+ GlobalState::lock(|_| Stream::destroy_impl(pstream))?
}
#[cfg(test)]
-mod tests {
+mod test {
use crate::cuda::CUstream;
use super::super::test::CudaDriverFns;
use super::super::CUresult;
- use std::ptr;
+ use std::{ptr, thread};
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
@@ -65,5 +143,100 @@ mod tests {
CUresult::CUDA_SUCCESS
);
assert_eq!(ctx2, stream_ctx2);
+ // Cleanup
+ assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS);
+ assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
+ }
+
+ cuda_driver_test!(stream_context_destroyed);
+
+ fn stream_context_destroyed<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ let mut stream = ptr::null_mut();
+ assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
+ let mut stream_ctx1 = ptr::null_mut();
+ assert_eq!(
+ T::cuStreamGetCtx(stream, &mut stream_ctx1),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(stream_ctx1, ctx);
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
+ let mut stream_ctx2 = ptr::null_mut();
+ // When a context gets destroyed, its streams are also destroyed
+ let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2);
+ assert!(
+ cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE
+ || cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT
+ || cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED
+ );
+ assert_eq!(
+ T::cuStreamDestroy_v2(stream),
+ CUresult::CUDA_ERROR_INVALID_HANDLE
+ );
+ // Check if creating another context is possible
+ let mut ctx2 = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS);
+ // Cleanup
+ assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS);
+ }
+
+ cuda_driver_test!(stream_moves_context_to_another_thread);
+
+ fn stream_moves_context_to_another_thread<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ let mut stream = ptr::null_mut();
+ assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
+ let mut stream_ctx1 = ptr::null_mut();
+ assert_eq!(
+ T::cuStreamGetCtx(stream, &mut stream_ctx1),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(stream_ctx1, ctx);
+ let stream_ptr = stream as usize;
+ let stream_ctx_on_thread = thread::spawn(move || {
+ let mut stream_ctx2 = ptr::null_mut();
+ assert_eq!(
+ T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2),
+ CUresult::CUDA_SUCCESS
+ );
+ stream_ctx2 as usize
+ })
+ .join()
+ .unwrap();
+ assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _);
+ // Cleanup
+ assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
+ }
+
+ cuda_driver_test!(can_destroy_stream);
+
+ fn can_destroy_stream<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ let mut stream = ptr::null_mut();
+ assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS);
+ assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS);
+ // Cleanup
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
+ }
+
+ cuda_driver_test!(cant_destroy_default_stream);
+
+ fn cant_destroy_default_stream<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ assert_ne!(
+ T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _),
+ CUresult::CUDA_SUCCESS
+ );
+ // Cleanup
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
}
diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs
index dbd2eff..b6ed926 100644
--- a/notcuda/src/impl/test.rs
+++ b/notcuda/src/impl/test.rs
@@ -1,8 +1,12 @@
#![allow(non_snake_case)]
-use crate::{cuda::CUstream, r#impl as notcuda};
-use crate::r#impl::CUresult;
-use crate::{cuda::CUuuid, r#impl::Encuda};
+use crate::cuda as notcuda;
+use crate::cuda::CUstream;
+use crate::cuda::CUuuid;
+use crate::{
+ cuda::{CUdevice, CUdeviceptr},
+ r#impl::CUresult,
+};
use ::std::{
ffi::c_void,
os::raw::{c_int, c_uint},
@@ -37,48 +41,63 @@ pub trait CudaDriverFns {
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult;
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult;
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult;
+ fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult;
+ fn cuMemFree_v2(mem: *mut c_void) -> CUresult;
+ fn cuStreamDestroy_v2(stream: CUstream) -> CUresult;
}
pub struct NotCuda();
impl CudaDriverFns for NotCuda {
fn cuInit(_flags: c_uint) -> CUresult {
- crate::cuda::cuInit(_flags as _)
+ notcuda::cuInit(_flags as _)
}
fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult {
- notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda()
+ notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev))
}
fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult {
- notcuda::context::destroy_v2(ctx as *mut _)
+ notcuda::cuCtxDestroy_v2(ctx as *mut _)
}
fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult {
- notcuda::context::pop_current_v2(pctx as *mut _)
+ notcuda::cuCtxPopCurrent_v2(pctx as *mut _)
}
fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult {
- notcuda::context::get_api_version(ctx as *mut _, version)
+ notcuda::cuCtxGetApiVersion(ctx as *mut _, version)
}
fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult {
- notcuda::context::get_current(pctx as *mut _).encuda()
+ notcuda::cuCtxGetCurrent(pctx as *mut _)
}
fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
- notcuda::memory::alloc_v2(dptr as *mut _, bytesize)
+ notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize)
}
fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult {
- notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda()
+ notcuda::cuDeviceGetUuid(uuid, CUdevice(dev))
}
fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult {
- notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda()
+ notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active)
}
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
- crate::cuda::cuStreamGetCtx(hStream, pctx as _)
+ notcuda::cuStreamGetCtx(hStream, pctx as _)
+ }
+
+ fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
+ notcuda::cuStreamCreate(stream, flags)
+ }
+
+ fn cuMemFree_v2(dptr: *mut c_void) -> CUresult {
+ notcuda::cuMemFree_v2(CUdeviceptr(dptr as _))
+ }
+
+ fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
+ notcuda::cuStreamDestroy_v2(stream)
}
}
@@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda {
fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult {
unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) }
}
+
+ fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult {
+ unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) }
+ }
+
+ fn cuMemFree_v2(mem: *mut c_void) -> CUresult {
+ unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) }
+ }
+
+ fn cuStreamDestroy_v2(stream: CUstream) -> CUresult {
+ unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) }
+ }
}
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 1aac8ab..591428f 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError;
-pub use translate::TranslateError as TranslateError;
-pub use translate::to_spirv;
+pub use translate::to_spirv_module;
+pub use translate::KernelInfo;
+pub use translate::TranslateError;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
x.into_iter().filter_map(|x| x).collect()
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index 0339141..0785f3e 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) {
fn compile_and_assert(s: &str) -> Result<(), TranslateError> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
- crate::to_spirv(ast)?;
+ crate::to_spirv_module(ast)?;
Ok(())
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c0e15f2..3d0f476 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,7 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -450,6 +450,11 @@ pub struct Module {
pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
}
+impl Module {
+ pub fn assemble(&self) -> Vec<u32> {
+ self.spirv.assemble()
+ }
+}
pub struct KernelInfo {
pub arguments_sizes: Vec<usize>,
@@ -1046,8 +1051,12 @@ fn emit_function_header<'a>(
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let MethodName::Kernel(name) = func_decl.name {
- let args_lens = func_decl
- .input
+ let input_args = if !func_decl.uses_shared_mem {
+ func_decl.input.as_slice()
+ } else {
+ &func_decl.input[0..func_decl.input.len() - 1]
+ };
+ let args_lens = input_args
.iter()
.map(|param| param.v_type.size_of())
.collect();
@@ -1135,21 +1144,6 @@ fn emit_function_header<'a>(
Ok(())
}
-pub fn to_spirv<'a>(
- ast: ast::Module<'a>,
-) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
- let module = to_spirv_module(ast)?;
- Ok((
- module.should_link_ptx_impl,
- module.spirv.assemble(),
- module
- .kernel_info
- .into_iter()
- .map(|(k, v)| (k, v.arguments_sizes))
- .collect(),
- ))
-}
-
fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::GenericPointer);
builder.capability(spirv::Capability::Linkage);