diff options
author | Andrzej Janik <[email protected]> | 2020-11-11 22:35:34 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-11-12 20:12:14 +0100 |
commit | a2e77fe961fb543370261e844d6cd79e0269e877 (patch) | |
tree | dea4c9c506c73941403de8ab46560bd569b4a7e7 | |
parent | 7c93997cc9b90886b6371ca3b93e21e7e6ae073d (diff) | |
download | ZLUDA-a2e77fe961fb543370261e844d6cd79e0269e877.tar.gz ZLUDA-a2e77fe961fb543370261e844d6cd79e0269e877.zip |
Refactor host code to use one big lock
-rw-r--r-- | level_zero/src/ze.rs | 12 | ||||
-rw-r--r-- | notcuda/build.rs | 27 | ||||
-rw-r--r-- | notcuda/src/cuda.rs | 25 | ||||
-rw-r--r-- | notcuda/src/impl/context.rs | 191 | ||||
-rw-r--r-- | notcuda/src/impl/device.rs | 283 | ||||
-rw-r--r-- | notcuda/src/impl/export_table.rs | 83 | ||||
-rw-r--r-- | notcuda/src/impl/function.rs | 66 | ||||
-rw-r--r-- | notcuda/src/impl/memory.rs | 83 | ||||
-rw-r--r-- | notcuda/src/impl/mod.rs | 161 | ||||
-rw-r--r-- | notcuda/src/impl/module.rs | 206 | ||||
-rw-r--r-- | notcuda/src/impl/stream.rs | 203 | ||||
-rw-r--r-- | notcuda/src/impl/test.rs | 57 | ||||
-rw-r--r-- | ptx/src/lib.rs | 5 | ||||
-rw-r--r-- | ptx/src/test/mod.rs | 2 | ||||
-rw-r--r-- | ptx/src/translate.rs | 30 |
15 files changed, 904 insertions, 530 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index f8a2c3b..4267682 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -173,6 +173,16 @@ impl Context { check!(sys::zeContextCreate(drv.0, &ctx_desc, &mut result));
Ok(Context(result))
}
+
+ pub unsafe fn mem_free(&mut self, ptr: *mut c_void) -> Result<()> {
+ check! {
+ sys::zeMemFree(
+ self.0,
+ ptr,
+ )
+ };
+ Ok(())
+ }
}
impl Drop for Context {
@@ -239,7 +249,7 @@ pub struct Module(sys::ze_module_handle_t); impl Module {
// HACK ALERT
- // We use OpenCL for now to do SPIR-V linking, because Level0
+ // We use OpenCL for now to do SPIR-V linking, because Level0
// does not allow linking. Don't let presence of zeModuleDynamicLink fool
// you, it's not currently possible to create non-compiled modules.
// zeModuleCreate always compiles (builds and links).
diff --git a/notcuda/build.rs b/notcuda/build.rs new file mode 100644 index 0000000..3b8999f --- /dev/null +++ b/notcuda/build.rs @@ -0,0 +1,27 @@ +// HACK ALERT
+// This buidl script has been copy-pasted from cl-sys to avoid CUDA libraries
+// overriding path to OpenCL
+
+ fn main() {
+ if cfg!(windows) {
+ let known_sdk = [
+ // E.g. "c:\Program Files (x86)\Intel\OpenCL SDK\lib\x86\"
+ ("INTELOCLSDKROOT", "x64", "x86"),
+ // E.g. "c:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\Win32\"
+ ("CUDA_PATH", "x64", "Win32"),
+ // E.g. "C:\Program Files (x86)\AMD APP SDK\3.0\lib\x86\"
+ ("AMDAPPSDKROOT", "x86_64", "x86"),
+ ];
+
+ for info in known_sdk.iter() {
+ if let Ok(sdk) = std::env::var(info.0) {
+ let mut path = std::path::PathBuf::from(sdk);
+ path.push("lib");
+ path.push(if cfg!(target_arch="x86_64") { info.1 } else { info.2 });
+ println!("cargo:rustc-link-search=native={}", path.display());
+ }
+ }
+
+ println!("cargo:rustc-link-search=native=C:\\Program Files (x86)\\OCL_SDK_Light\\lib\\x86_64");
+ }
+}
\ No newline at end of file diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs index a18ebf9..335da4a 100644 --- a/notcuda/src/cuda.rs +++ b/notcuda/src/cuda.rs @@ -2210,12 +2210,12 @@ pub extern "C" fn cuDriverGetVersion(driverVersion: *mut ::std::os::raw::c_int) #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDeviceGet(device: *mut CUdevice, ordinal: ::std::os::raw::c_int) -> CUresult { - r#impl::device::get(device.decuda(), ordinal) + r#impl::device::get(device.decuda(), ordinal).encuda() } #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDeviceGetCount(count: *mut ::std::os::raw::c_int) -> CUresult { - r#impl::device::get_count(count) + r#impl::device::get_count(count).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2314,7 +2314,6 @@ pub extern "C" fn cuDevicePrimaryCtxReset(dev: CUdevice) -> CUresult { cuDevicePrimaryCtxReset_v2(dev) } - #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDevicePrimaryCtxReset_v2(dev: CUdevice) -> CUresult { r#impl::unimplemented() @@ -2331,7 +2330,7 @@ pub extern "C" fn cuCtxCreate_v2( #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxDestroy_v2(ctx: CUcontext) -> CUresult { - r#impl::context::destroy_v2(ctx.decuda()) + r#impl::context::destroy_v2(ctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2356,7 +2355,7 @@ pub extern "C" fn cuCtxGetCurrent(pctx: *mut CUcontext) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxGetDevice(device: *mut CUdevice) -> CUresult { - r#impl::context::get_device(device.decuda()) + r#impl::context::get_device(device.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2404,7 +2403,7 @@ pub extern "C" fn cuCtxGetApiVersion( ctx: CUcontext, version: *mut ::std::os::raw::c_uint, ) -> CUresult { - r#impl::context::get_api_version(ctx.decuda(), version) + r#impl::context::get_api_version(ctx.decuda(), version).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2422,12 +2421,12 @@ pub extern "C" fn cuCtxResetPersistingL2Cache() -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxAttach(pctx: *mut CUcontext, flags: ::std::os::raw::c_uint) -> CUresult { - r#impl::unimplemented() + r#impl::context::attach(pctx.decuda(), flags).encuda() } #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuCtxDetach(ctx: CUcontext) -> CUresult { - r#impl::unimplemented() + r#impl::context::detach(ctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2443,7 +2442,7 @@ pub extern "C" fn cuModuleLoadData( module: *mut CUmodule, image: *const ::std::os::raw::c_void, ) -> CUresult { - r#impl::unimplemented() + r#impl::module::load_data(module.decuda(), image).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -2564,7 +2563,7 @@ pub extern "C" fn cuMemGetInfo_v2(free: *mut usize, total: *mut usize) -> CUresu #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuMemAlloc_v2(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult { - r#impl::memory::alloc_v2(dptr.decuda(), bytesize) + r#impl::memory::alloc_v2(dptr.decuda(), bytesize).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3281,7 +3280,7 @@ pub extern "C" fn cuStreamCreate( phStream: *mut CUstream, Flags: ::std::os::raw::c_uint, ) -> CUresult { - r#impl::unimplemented() + r#impl::stream::create(phStream.decuda(), Flags).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3311,7 +3310,7 @@ pub extern "C" fn cuStreamGetFlags( #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuStreamGetCtx(hStream: CUstream, pctx: *mut CUcontext) -> CUresult { - r#impl::unimplemented() + r#impl::stream::get_ctx(hStream.decuda(), pctx.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] @@ -3390,7 +3389,7 @@ pub extern "C" fn cuStreamSynchronize(hStream: CUstream) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuStreamDestroy_v2(hStream: CUstream) -> CUresult { - r#impl::unimplemented() + r#impl::stream::destroy_v2(hStream.decuda()).encuda() } #[cfg_attr(not(test), no_mangle)] diff --git a/notcuda/src/impl/context.rs b/notcuda/src/impl/context.rs index 91d4460..9689ecf 100644 --- a/notcuda/src/impl/context.rs +++ b/notcuda/src/impl/context.rs @@ -1,18 +1,15 @@ -use super::CUresult; -use super::{device, HasLivenessCookie, LiveCheck}; +use super::{device, stream::Stream, stream::StreamData, HasLivenessCookie, LiveCheck}; +use super::{CUresult, GlobalState}; use crate::{cuda::CUcontext, cuda_impl}; use l0::sys::ze_result_t; -use std::mem::{self, ManuallyDrop}; +use std::{cell::RefCell, num::NonZeroU32, os::raw::c_uint, ptr, sync::atomic::AtomicU32}; use std::{ - cell::RefCell, - num::NonZeroU32, - os::raw::c_uint, - ptr, - sync::{atomic::AtomicU32, Mutex}, + collections::HashSet, + mem::{self}, }; thread_local! { - pub static CONTEXT_STACK: RefCell<Vec<*const Context>> = RefCell::new(Vec::new()); + pub static CONTEXT_STACK: RefCell<Vec<*mut Context>> = RefCell::new(Vec::new()); } pub type Context = LiveCheck<ContextData>; @@ -23,6 +20,17 @@ impl HasLivenessCookie for ContextData { #[cfg(target_pointer_width = "32")] const COOKIE: usize = 0x0b643ffb; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_CONTEXT; + + fn try_drop(&mut self) -> Result<(), CUresult> { + for stream in self.streams.iter() { + let stream = unsafe { &mut **stream }; + stream.context = ptr::null_mut(); + Stream::destroy_impl(unsafe { Stream::ptr_from_inner(stream) })?; + } + Ok(()) + } } enum ContextRefCount { @@ -67,26 +75,16 @@ impl ContextRefCount { } } } - - fn is_primary(&self) -> bool { - match self { - ContextRefCount::Primary => true, - ContextRefCount::NonPrimary(_) => false, - } - } } pub struct ContextData { pub flags: AtomicU32, - pub device_index: device::Index, // This pointer is null only for a moment when constructing primary context - pub device: *const Mutex<device::Device>, - // The split between mutable / non-mutable is mainly to avoid recursive locking in cuDevicePrimaryCtxGetState - pub mutable: Mutex<ContextDataMutable>, -} - -pub struct ContextDataMutable { + pub device: *mut device::Device, ref_count: ContextRefCount, + pub default_stream: StreamData, + pub streams: HashSet<*mut StreamData>, + // All the fields below are here to support internal CUDA driver API pub cuda_manager: *mut cuda_impl::rt::ContextStateManager, pub cuda_state: *mut cuda_impl::rt::ContextState, pub cuda_dtor_cb: Option< @@ -100,63 +98,75 @@ pub struct ContextDataMutable { impl ContextData { pub fn new( + l0_ctx: &mut l0::Context, + l0_dev: &l0::Device, flags: c_uint, is_primary: bool, - dev_index: device::Index, - dev: *const Mutex<device::Device>, - ) -> Self { - ContextData { + dev: *mut device::Device, + ) -> Result<Self, CUresult> { + let default_stream = StreamData::new_unitialized(l0_ctx, l0_dev)?; + Ok(ContextData { flags: AtomicU32::new(flags), - device_index: dev_index, device: dev, - mutable: Mutex::new(ContextDataMutable { - ref_count: ContextRefCount::new(is_primary), - cuda_manager: ptr::null_mut(), - cuda_state: ptr::null_mut(), - cuda_dtor_cb: None, - }), - } + ref_count: ContextRefCount::new(is_primary), + default_stream, + streams: HashSet::new(), + cuda_manager: ptr::null_mut(), + cuda_state: ptr::null_mut(), + cuda_dtor_cb: None, + }) } } -pub fn create_v2(pctx: *mut *mut Context, flags: u32, dev_idx: device::Index) -> CUresult { +impl Context { + pub fn late_init(&mut self) { + let ctx_data = self.as_option_mut().unwrap(); + ctx_data.default_stream.context = ctx_data as *mut _; + } +} + +pub fn create_v2( + pctx: *mut *mut Context, + flags: u32, + dev_idx: device::Index, +) -> Result<(), CUresult> { if pctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let dev = device::get_device_ref(dev_idx); - let dev = match dev { - Ok(d) => d, - Err(e) => return e, - }; - let mut ctx = Box::new(LiveCheck::new(ContextData::new(flags, false, dev_idx, dev))); - let ctx_ref = ctx.as_mut() as *mut Context; + let mut ctx_box = GlobalState::lock_device(dev_idx, |dev| { + let dev_ptr = dev as *mut _; + let mut ctx_box = Box::new(LiveCheck::new(ContextData::new( + &mut dev.l0_context, + &dev.base, + flags, + false, + dev_ptr as *mut _, + )?)); + ctx_box.late_init(); + Ok::<_, CUresult>(ctx_box) + })??; + let ctx_ref = ctx_box.as_mut() as *mut Context; unsafe { *pctx = ctx_ref }; - mem::forget(ctx); + mem::forget(ctx_box); CONTEXT_STACK.with(|stack| stack.borrow_mut().push(ctx_ref)); - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn destroy_v2(ctx: *mut Context) -> CUresult { +pub fn destroy_v2(ctx: *mut Context) -> Result<(), CUresult> { if ctx == ptr::null_mut() { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } CONTEXT_STACK.with(|stack| { let mut stack = stack.borrow_mut(); let should_pop = match stack.last() { - Some(active_ctx) => *active_ctx == (ctx as *const _), + Some(active_ctx) => *active_ctx == (ctx as *mut _), None => false, }; if should_pop { stack.pop(); } }); - let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(ctx) }); - if !ctx_box.try_drop() { - CUresult::CUDA_ERROR_INVALID_CONTEXT - } else { - unsafe { ManuallyDrop::drop(&mut ctx_box) }; - CUresult::CUDA_SUCCESS - } + GlobalState::lock(|_| Context::destroy_impl(ctx))? } pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { @@ -172,17 +182,6 @@ pub fn pop_current_v2(pctx: *mut *mut Context) -> CUresult { CUresult::CUDA_SUCCESS } -pub fn with_current<F: FnOnce(&ContextData) -> R, R>(f: F) -> Result<R, CUresult> { - CONTEXT_STACK.with(|stack| { - stack - .borrow() - .last() - .and_then(|c| unsafe { &**c }.as_ref()) - .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) - .map(f) - }) -} - pub fn get_current(pctx: *mut *mut Context) -> l0::Result<()> { if pctx == ptr::null_mut() { return Err(ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT); @@ -205,37 +204,53 @@ pub fn set_current(ctx: *mut Context) -> CUresult { } } -pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> CUresult { - let _ctx = match unsafe { ctx.as_mut() } { - None => return CUresult::CUDA_ERROR_INVALID_VALUE, - Some(ctx) => match ctx.as_mut() { - None => return CUresult::CUDA_ERROR_INVALID_CONTEXT, - Some(ctx) => ctx, - }, - }; +pub fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { + if ctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| { + unsafe { &*ctx }.as_result()?; + Ok::<_, CUresult>(()) + })??; //TODO: query device for properties roughly matching CUDA API version unsafe { *version = 1100 }; - CUresult::CUDA_SUCCESS + Ok(()) } -pub fn get_device(dev: *mut device::Index) -> CUresult { - let dev_idx = with_current(|ctx| ctx.device_index); - match dev_idx { - Ok(idx) => { - unsafe { *dev = idx } - CUresult::CUDA_SUCCESS - } - Err(err) => err, +pub fn get_device(dev: *mut device::Index) -> Result<(), CUresult> { + let dev_idx = GlobalState::lock_current_context(|ctx| unsafe { &*ctx.device }.index)?; + unsafe { *dev = dev_idx }; + Ok(()) +} + +pub fn attach(pctx: *mut *mut Context, _flags: c_uint) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } + let ctx = GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + ctx.ref_count.incr()?; + Ok::<_, CUresult>(unchecked_ctx as *mut _) + })??; + unsafe { *pctx = ctx }; + Ok(()) } -#[cfg(test)] -pub fn is_context_stack_empty() -> bool { - CONTEXT_STACK.with(|stack| stack.borrow().is_empty()) +pub fn detach(pctx: *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock_current_context_unchecked(|unchecked_ctx| { + let ctx = unchecked_ctx.as_result_mut()?; + if ctx.ref_count.decr() { + Context::destroy_impl(unchecked_ctx)?; + } + Ok::<_, CUresult>(()) + })? } #[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; use std::{ffi::c_void, ptr}; diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs index d4859d3..b8d263d 100644 --- a/notcuda/src/impl/device.rs +++ b/notcuda/src/impl/device.rs @@ -1,24 +1,21 @@ -use super::{context, transmute_lifetime, CUresult, Error}; +use super::{context, CUresult, GlobalState}; use crate::cuda; use cuda::{CUdevice_attribute, CUuuid_st}; use std::{ cmp, mem, os::raw::{c_char, c_int}, ptr, - sync::{ - atomic::{AtomicU32, Ordering}, - Mutex, MutexGuard, - }, + sync::atomic::{AtomicU32, Ordering}, }; const PROJECT_URL_SUFFIX: &'static str = " [github.com/vosen/notCUDA]"; -static mut DEVICES: Option<Vec<Mutex<Device>>> = None; #[repr(transparent)] -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Eq, PartialEq, Hash)] pub struct Index(pub c_int); pub struct Device { + pub index: Index, pub base: l0::Device, pub default_queue: l0::CommandQueue, pub l0_context: l0::Context, @@ -33,17 +30,19 @@ unsafe impl Send for Device {} impl Device { // Unsafe because it does not fully initalize primary_context - unsafe fn new(drv: &l0::Driver, d: l0::Device, idx: usize) -> l0::Result<Self> { + unsafe fn new(drv: &l0::Driver, l0_dev: l0::Device, idx: usize) -> Result<Self, CUresult> { let mut ctx = l0::Context::new(drv)?; - let queue = l0::CommandQueue::new(&mut ctx, &d)?; + let queue = l0::CommandQueue::new(&mut ctx, &l0_dev)?; let primary_context = context::Context::new(context::ContextData::new( + &mut ctx, + &l0_dev, 0, true, - Index(idx as c_int), - ptr::null(), - )); + ptr::null_mut(), + )?); Ok(Self { - base: d, + index: Index(idx as c_int), + base: l0_dev, default_queue: queue, l0_context: ctx, primary_context: primary_context, @@ -93,83 +92,53 @@ impl Device { Err(e) => Err(e), } } + + pub fn late_init(&mut self) { + self.primary_context.as_option_mut().unwrap().device = self as *mut _; + } } -pub fn init(driver: &l0::Driver) -> l0::Result<()> { +pub fn init(driver: &l0::Driver) -> Result<Vec<Device>, CUresult> { let ze_devices = driver.devices()?; let mut devices = ze_devices .into_iter() .enumerate() - .map(|(idx, d)| unsafe { Device::new(driver, d, idx) }.map(Mutex::new)) + .map(|(idx, d)| unsafe { Device::new(driver, d, idx) }) .collect::<Result<Vec<_>, _>>()?; - for d in devices.iter_mut() { - d.get_mut() - .unwrap() - .primary_context - .as_mut() - .unwrap() - .device = d; - } - unsafe { DEVICES = Some(devices) }; - Ok(()) -} - -fn devices() -> Result<&'static Vec<Mutex<Device>>, CUresult> { - match unsafe { &DEVICES } { - Some(devs) => Ok(devs), - None => Err(CUresult::CUDA_ERROR_NOT_INITIALIZED), - } -} - -pub fn get_device_ref(Index(dev_idx): Index) -> Result<&'static Mutex<Device>, CUresult> { - let devs = devices()?; - if dev_idx < 0 || dev_idx >= devs.len() as c_int { - return Err(CUresult::CUDA_ERROR_INVALID_DEVICE); + for dev in devices.iter_mut() { + dev.late_init(); + dev.primary_context.late_init(); } - Ok(&devs[dev_idx as usize]) + Ok(devices) } -pub fn get_device(dev_idx: Index) -> Result<MutexGuard<'static, Device>, CUresult> { - let dev = get_device_ref(dev_idx)?; - dev.lock().map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) -} - -pub fn get_count(count: *mut c_int) -> CUresult { - let len = devices().map(|d| d.len()); - match len { - Ok(len) => { - unsafe { *count = len as c_int }; - CUresult::CUDA_SUCCESS - } - Err(e) => e, - } +pub fn get_count(count: *mut c_int) -> Result<(), CUresult> { + let len = GlobalState::lock(|state| state.devices.len())?; + unsafe { *count = len as c_int }; + Ok(()) } -pub fn get(device: *mut Index, ordinal: c_int) -> CUresult { +pub fn get(device: *mut Index, ordinal: c_int) -> Result<(), CUresult> { if device == ptr::null_mut() || ordinal < 0 { - return CUresult::CUDA_ERROR_INVALID_VALUE; + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let len = devices().map(|d| d.len()); - match len { - Ok(len) if ordinal < (len as i32) => { - unsafe { *device = Index(ordinal) }; - CUresult::CUDA_SUCCESS - } - Ok(_) => CUresult::CUDA_ERROR_INVALID_VALUE, - Err(e) => e, + let len = GlobalState::lock(|state| state.devices.len())?; + if ordinal < (len as i32) { + unsafe { *device = Index(ordinal) }; + Ok(()) + } else { + Err(CUresult::CUDA_ERROR_INVALID_VALUE) } } -pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> { +pub fn get_name(name: *mut c_char, len: i32, dev_idx: Index) -> Result<(), CUresult> { if name == ptr::null_mut() || len < 0 { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - // This is safe because devices are 'static - let name_ptr = { - let mut dev = get_device(dev)?; - let props = dev.get_properties().map_err(Into::<CUresult>::into)?; - props.name.as_ptr() - }; + let name_ptr = GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.name.as_ptr()) + })??; let name_len = (0..256) .position(|i| unsafe { *name_ptr.add(i) } == 0) .unwrap_or(256); @@ -189,20 +158,14 @@ pub fn get_name(name: *mut c_char, len: i32, dev: Index) -> Result<(), CUresult> Ok(()) } -pub fn total_mem_v2(bytes: *mut usize, dev: Index) -> Result<(), CUresult> { +pub fn total_mem_v2(bytes: *mut usize, dev_idx: Index) -> Result<(), CUresult> { if bytes == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - // This is safe because devices are 'static - let mem_props = { - let mut dev = get_device(dev)?; - unsafe { - transmute_lifetime( - dev.get_memory_properties() - .map_err(Into::<CUresult>::into)?, - ) - } - }; + let mem_props = GlobalState::lock_device(dev_idx, |dev| { + let mem_props = dev.get_memory_properties()?; + Ok::<_, l0::sys::ze_result_t>(mem_props) + })??; let max_mem = mem_props .iter() .map(|p| p.totalSize) @@ -228,56 +191,101 @@ impl CUdevice_attribute { } } -pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Result<(), Error> { +pub fn get_attribute( + pi: *mut i32, + attrib: CUdevice_attribute, + dev_idx: Index, +) -> Result<(), CUresult> { if pi == ptr::null_mut() { - return Err(Error::Cuda(CUresult::CUDA_ERROR_INVALID_VALUE)); + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } if let Some(value) = attrib.get_static_value() { unsafe { *pi = value }; return Ok(()); } - let mut dev = get_device(dev).map_err(Error::Cuda)?; let value = match attrib { CUdevice_attribute::CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT => { - dev.get_properties().map_err(Error::L0)?.maxHardwareContexts as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.maxHardwareContexts as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT => { - let props = dev.get_properties().map_err(Error::L0)?; - (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>( + (props.numSlices * props.numSubslicesPerSlice * props.numEUsPerSubslice) as i32, + ) + })?? + } + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => { + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_image_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::min( + props.maxImageDims1D, + c_int::max_value() as u32, + ) as c_int) + })?? } - CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH => cmp::min( - dev.get_image_properties() - .map_err(Error::L0)? - .maxImageDims1D, - c_int::max_value() as u32, - ) as c_int, CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountX) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountX, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountY) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountY, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupCountZ) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxGroupCountZ, + ) as i32) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeX) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeY) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>( + cmp::max(i32::max_value() as u32, props.maxGroupSizeZ) as i32, + ) + })?? } CUdevice_attribute::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK => { - let props = dev.get_compute_properties().map_err(Error::L0)?; - cmp::max(i32::max_value() as u32, props.maxTotalGroupSize) as i32 + GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_compute_properties()?; + Ok::<_, l0::sys::ze_result_t>(cmp::max( + i32::max_value() as u32, + props.maxTotalGroupSize, + ) as i32) + })?? } _ => { // TODO: support more attributes for CUDA runtime @@ -293,14 +301,11 @@ pub fn get_attribute(pi: *mut i32, attrib: CUdevice_attribute, dev: Index) -> Re Ok(()) } -pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> { - let ze_uuid = { - get_device(dev) - .map_err(Error::Cuda)? - .get_properties() - .map_err(Error::L0)? - .uuid - }; +pub fn get_uuid(uuid: *mut CUuuid_st, dev_idx: Index) -> Result<(), CUresult> { + let ze_uuid = GlobalState::lock_device(dev_idx, |dev| { + let props = dev.get_properties()?; + Ok::<_, l0::sys::ze_result_t>(props.uuid) + })??; unsafe { *uuid = CUuuid_st { bytes: mem::transmute(ze_uuid.id), @@ -309,53 +314,39 @@ pub fn get_uuid(uuid: *mut CUuuid_st, dev: Index) -> Result<(), Error> { Ok(()) } -pub fn with_current_exclusive<F: FnOnce(&mut Device) -> R, R>(f: F) -> Result<R, CUresult> { - let dev = super::context::with_current(|ctx| ctx.device); - dev.and_then(|dev| { - unsafe { &*dev } - .try_lock() - .map(|mut dev| f(&mut dev)) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) - }) -} - -pub fn with_exclusive<F: FnOnce(&mut Device) -> R, R>(dev: Index, f: F) -> Result<R, CUresult> { - let dev = get_device_ref(dev)?; - dev.try_lock() - .map(|mut dev| f(&mut dev)) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE) -} - pub fn primary_ctx_get_state( - idx: Index, + dev_idx: Index, flags: *mut u32, active: *mut i32, ) -> Result<(), CUresult> { - let (ctx_ptr, flags_ptr) = with_exclusive(idx, |dev| { + let (is_active, flags_value) = GlobalState::lock_device(dev_idx, |dev| { // This is safe because primary context can't be dropped - let ctx_ptr = &dev.primary_context as *const _; + let ctx_ptr = &mut dev.primary_context as *mut _; let flags_ptr = (&unsafe { dev.primary_context.as_ref_unchecked() }.flags) as *const AtomicU32; - (ctx_ptr, flags_ptr) - })?; - let is_active = context::CONTEXT_STACK - .with(|stack| stack.borrow().last().map(|x| *x)) - .map(|current| current == ctx_ptr) - .unwrap_or(false); - let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed); - unsafe { *flags = flags_value }; + let is_active = context::CONTEXT_STACK + .with(|stack| stack.borrow().last().map(|x| *x)) + .map(|current| current == ctx_ptr) + .unwrap_or(false); + let flags_value = unsafe { &*flags_ptr }.load(Ordering::Relaxed); + Ok::<_, l0::sys::ze_result_t>((is_active, flags_value)) + })??; unsafe { *active = if is_active { 1 } else { 0 } }; + unsafe { *flags = flags_value }; Ok(()) } -pub fn primary_ctx_retain(pctx: *mut *mut context::Context, dev: Index) -> Result<(), CUresult> { - let ctx_ptr = with_exclusive(dev, |dev| &mut dev.primary_context as *mut _)?; +pub fn primary_ctx_retain( + pctx: *mut *mut context::Context, + dev_idx: Index, +) -> Result<(), CUresult> { + let ctx_ptr = GlobalState::lock_device(dev_idx, |dev| &mut dev.primary_context as *mut _)?; unsafe { *pctx = ctx_ptr }; Ok(()) } #[cfg(test)] -mod tests { +mod test { use super::super::test::CudaDriverFns; use super::super::CUresult; diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index 562af37..ae9f6e3 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -4,7 +4,7 @@ use crate::{ cuda_impl,
};
-use super::{context, device, module, Decuda, Encuda};
+use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [ VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
- cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
-}
-
-fn cudart_interface_fn1_impl(
- pctx: *mut *mut context::Context,
- dev: device::Index,
-) -> Result<(), CUresult> {
- let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
- unsafe { *pctx = ctx_ptr };
- Ok(())
+unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
+ super::unimplemented()
}
/*
@@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin( ptr1: *mut c_void,
ptr2: *mut c_void,
) -> CUresult {
- // Not sure what those twoparameters are actually used for,
+ // Not sure what those two parameters are actually used for,
// they are somehow involved in __cudaRegisterHostVar
if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
return CUresult::CUDA_ERROR_NOT_SUPPORTED;
@@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin( },
Err(_) => continue,
};
- let module = module::ModuleData::compile_spirv(kernel_text_string);
+ let module = module::SpirvModule::new(kernel_text_string);
match module {
Ok(module) => {
- *result = Box::into_raw(Box::new(module));
+ match module::load_data_impl(result, module) {
+ Ok(()) => {}
+ Err(err) => return err,
+ }
return CUresult::CUDA_SUCCESS;
}
Err(_) => continue,
@@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor( }
fn context_local_storage_ctor_impl(
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option<
@@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl( ),
>,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mut mutable| {
- mutable.cuda_manager = mgr;
- mutable.cuda_state = ctx_state;
- mutable.cuda_dtor_cb = dtor_cb;
- })
- })?;
- Ok(())
+ lock_context(cu_ctx, |ctx: &mut ContextData| {
+ ctx.cuda_manager = mgr;
+ ctx.cuda_state = ctx_state;
+ ctx.cuda_dtor_cb = dtor_cb;
+ })
}
// some kind of dtor
@@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state( fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState,
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- let cuda_state = unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mutable| mutable.cuda_state)
- })?;
+ let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
if cuda_state == ptr::null_mut() {
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else {
@@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl( Ok(())
}
}
+
+fn lock_context<T>(
+ cu_ctx: *mut context::Context,
+ fn_impl: impl FnOnce(&mut ContextData) -> T,
+) -> Result<T, CUresult> {
+ if cu_ctx == ptr::null_mut() {
+ GlobalState::lock_current_context(fn_impl)
+ } else {
+ GlobalState::lock(|_| {
+ let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
+ Ok(fn_impl(ctx))
+ })?
+ }
+}
diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs index 0ab3bea..394f806 100644 --- a/notcuda/src/impl/function.rs +++ b/notcuda/src/impl/function.rs @@ -1,11 +1,28 @@ use ::std::os::raw::{c_uint, c_void}; use std::ptr; -use super::{device, stream::Stream, CUresult}; +use super::{CUresult, GlobalState, HasLivenessCookie, LiveCheck, stream::Stream}; -pub struct Function { +pub type Function = LiveCheck<FunctionData>; + +impl HasLivenessCookie for FunctionData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x5e2ab14d5840678e; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x33e6a1e6; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + Ok(()) + } +} + +pub struct FunctionData { pub base: l0::Kernel<'static>, pub arg_size: Vec<usize>, + pub use_shared_mem: bool, } pub fn launch_kernel( @@ -17,36 +34,43 @@ pub fn launch_kernel( block_dim_y: c_uint, block_dim_z: c_uint, shared_mem_bytes: c_uint, - strean: *mut Stream, + hstream: *mut Stream, kernel_params: *mut *mut c_void, extra: *mut *mut c_void, ) -> Result<(), CUresult> { if f == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - if shared_mem_bytes != 0 || strean != ptr::null_mut() || extra != ptr::null_mut() { + if extra != ptr::null_mut() { return Err(CUresult::CUDA_ERROR_NOT_SUPPORTED); } - let func = unsafe { &*f }; - for (i, arg_size) in func.arg_size.iter().copied().enumerate() { - unsafe { - func.base - .set_arg_raw(i as u32, arg_size, *kernel_params.add(i))? - }; - } - unsafe { &*f } - .base - .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; - device::with_current_exclusive(|dev| { - let mut cmd_list = l0::CommandList::new(&mut dev.l0_context, &dev.base)?; + GlobalState::lock_stream(hstream, |stream| { + let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?; + for (i, arg_size) in func.arg_size.iter().enumerate() { + unsafe { + func.base + .set_arg_raw(i as u32, *arg_size, *kernel_params.add(i))? + }; + } + if func.use_shared_mem { + unsafe { + func.base.set_arg_raw( + func.arg_size.len() as u32, + shared_mem_bytes as usize, + ptr::null(), + )? + }; + } + func.base + .set_group_size(block_dim_x, block_dim_y, block_dim_z)?; + let mut cmd_list = stream.command_list()?; cmd_list.append_launch_kernel( - &unsafe { &*f }.base, + &mut func.base, &[grid_dim_x, grid_dim_y, grid_dim_z], None, &mut [], )?; - dev.default_queue.execute(cmd_list)?; - l0::Result::Ok(()) - })??; - Ok(()) + stream.queue.execute(cmd_list)?; + Ok(()) + })? } diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs index 439b26f..62dc1cc 100644 --- a/notcuda/src/impl/memory.rs +++ b/notcuda/src/impl/memory.rs @@ -1,57 +1,34 @@ -use super::CUresult;
+use super::{stream, CUresult, GlobalState};
use std::ffi::c_void;
-pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult {
- let alloc_result = super::device::with_current_exclusive(|dev| unsafe {
- dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0)
- });
- match alloc_result {
- Ok(Ok(alloc)) => {
- unsafe { *dptr = alloc };
- CUresult::CUDA_SUCCESS
- }
- Ok(Err(e)) => e.into(),
- Err(e) => e,
- }
+pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
+ let ptr = GlobalState::lock_current_context(|ctx| {
+ let dev = unsafe { &mut *ctx.device };
+ Ok::<_, CUresult>(unsafe { dev.base.mem_alloc_device(&mut dev.l0_context, bytesize, 0) }?)
+ })??;
+ unsafe { *dptr = ptr };
+ Ok(())
}
-pub fn copy_v2(
- dst: *mut c_void,
- src: *const c_void,
- bytesize: usize,
-) -> Result<Result<(), l0::sys::ze_result_t>, CUresult> {
- super::device::with_current_exclusive(|dev| unsafe {
- memcpy_impl(
- &mut dev.l0_context,
- dst,
- src,
- bytesize,
- &dev.base,
- &mut dev.default_queue,
- )
- })
+pub fn copy_v2(dst: *mut c_void, src: *const c_void, bytesize: usize) -> Result<(), CUresult> {
+ GlobalState::lock_stream(stream::CU_STREAM_LEGACY, |stream| {
+ let mut cmd_list = stream.command_list()?;
+ unsafe { cmd_list.append_memory_copy_unsafe(dst, src, bytesize, None, &mut []) }?;
+ stream.queue.execute(cmd_list)?;
+ Ok::<_, CUresult>(())
+ })?
}
-unsafe fn memcpy_impl(
- ctx: &mut l0::Context,
- dst: *mut c_void,
- src: *const c_void,
- bytes_count: usize,
- dev: &l0::Device,
- queue: &mut l0::CommandQueue,
-) -> l0::Result<()> {
- let mut cmd_list = l0::CommandList::new(ctx, &dev)?;
- cmd_list.append_memory_copy_unsafe(dst, src, bytes_count, None, &mut [])?;
- queue.execute(cmd_list)?;
- Ok(())
-}
-
-pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
- Ok(())
+pub fn free_v2(ptr: *mut c_void) -> Result<(), CUresult> {
+ GlobalState::lock_current_context(|ctx| {
+ let dev = unsafe { &mut *ctx.device };
+ Ok::<_, CUresult>(unsafe { dev.l0_context.mem_free(ptr) }?)
+ })
+ .map_err(|_| CUresult::CUDA_ERROR_INVALID_VALUE)?
}
#[cfg(test)]
-mod tests {
+mod test {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
use std::ptr;
@@ -82,4 +59,20 @@ mod tests { assert_ne!(mem, ptr::null_mut());
assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
}
+
+ cuda_driver_test!(free_without_ctx);
+
+ fn free_without_ctx<T: CudaDriverFns>() {
+ assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS);
+ let mut ctx = ptr::null_mut();
+ assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS);
+ let mut mem = ptr::null_mut();
+ assert_eq!(
+ T::cuMemAlloc_v2(&mut mem, std::mem::size_of::<usize>()),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_ne!(mem, ptr::null_mut());
+ assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS);
+ assert_eq!(T::cuMemFree_v2(mem), CUresult::CUDA_ERROR_INVALID_VALUE);
+ }
}
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 5a72ce4..770a32b 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,5 +1,15 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; -use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; +use crate::{ + cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}, + r#impl::device::Device, +}; +use std::{ + ffi::c_void, + mem::{self, ManuallyDrop}, + os::raw::c_int, + ptr, + sync::Mutex, + sync::TryLockError, +}; #[cfg(test)] #[macro_use] @@ -7,9 +17,9 @@ pub mod test; pub mod context; pub mod device; pub mod export_table; +pub mod function; pub mod memory; pub mod module; -pub mod function; pub mod stream; #[cfg(debug_assertions)] @@ -22,8 +32,11 @@ pub fn unimplemented() -> CUresult { CUresult::CUDA_ERROR_NOT_SUPPORTED } -pub trait HasLivenessCookie { +pub trait HasLivenessCookie: Sized { const COOKIE: usize; + const LIVENESS_FAIL: CUresult; + + fn try_drop(&mut self) -> Result<(), CUresult>; } // This struct is a best-effort check if wrapped value has been dropped, @@ -42,34 +55,55 @@ impl<T: HasLivenessCookie> LiveCheck<T> { } } + fn destroy_impl(this: *mut Self) -> Result<(), CUresult> { + let mut ctx_box = ManuallyDrop::new(unsafe { Box::from_raw(this) }); + ctx_box.try_drop()?; + unsafe { ManuallyDrop::drop(&mut ctx_box) }; + Ok(()) + } + + unsafe fn ptr_from_inner(this: *mut T) -> *mut Self { + let outer_ptr = (this as *mut u8).sub(mem::size_of::<usize>()); + outer_ptr as *mut Self + } + pub unsafe fn as_ref_unchecked(&self) -> &T { &self.data } - pub fn as_ref(&self) -> Option<&T> { + pub fn as_option_mut(&mut self) -> Option<&mut T> { if self.cookie == T::COOKIE { - Some(&self.data) + Some(&mut self.data) } else { None } } - pub fn as_mut(&mut self) -> Option<&mut T> { + pub fn as_result(&self) -> Result<&T, CUresult> { if self.cookie == T::COOKIE { - Some(&mut self.data) + Ok(&self.data) } else { - None + Err(T::LIVENESS_FAIL) + } + } + + pub fn as_result_mut(&mut self) -> Result<&mut T, CUresult> { + if self.cookie == T::COOKIE { + Ok(&mut self.data) + } else { + Err(T::LIVENESS_FAIL) } } #[must_use] - pub fn try_drop(&mut self) -> bool { + pub fn try_drop(&mut self) -> Result<(), CUresult> { if self.cookie == T::COOKIE { self.cookie = 0; + self.data.try_drop()?; unsafe { ManuallyDrop::drop(&mut self.data) }; - return true; + return Ok(()); } - false + Err(T::LIVENESS_FAIL) } } @@ -121,6 +155,12 @@ impl From<l0::sys::ze_result_t> for CUresult { } } +impl<T> From<TryLockError<T>> for CUresult { + fn from(_: TryLockError<T>) -> Self { + CUresult::CUDA_ERROR_ILLEGAL_STATE + } +} + pub trait Encuda { type To: Sized; fn encuda(self: Self) -> Self::To; @@ -157,58 +197,103 @@ impl<T1: Encuda<To = CUresult>, T2: Encuda<To = CUresult>> Encuda for Result<T1, } } -pub enum Error { - L0(l0::sys::ze_result_t), - Cuda(CUresult), -} - -impl Encuda for Error { - type To = CUresult; - fn encuda(self: Self) -> Self::To { - match self { - Error::L0(e) => e.into(), - Error::Cuda(e) => e, - } - } -} - lazy_static! { static ref GLOBAL_STATE: Mutex<Option<GlobalState>> = Mutex::new(None); } struct GlobalState { - driver: l0::Driver, + devices: Vec<Device>, } unsafe impl Send for GlobalState {} +impl GlobalState { + fn lock<T>(f: impl FnOnce(&mut GlobalState) -> T) -> Result<T, CUresult> { + let mut mutex = GLOBAL_STATE + .lock() + .unwrap_or_else(|poison| poison.into_inner()); + let global_state = mutex.as_mut().ok_or(CUresult::CUDA_ERROR_ILLEGAL_STATE)?; + Ok(f(global_state)) + } + + fn lock_device<T>( + device::Index(dev_idx): device::Index, + f: impl FnOnce(&'static mut device::Device) -> T, + ) -> Result<T, CUresult> { + if dev_idx < 0 { + return Err(CUresult::CUDA_ERROR_INVALID_DEVICE); + } + Self::lock(|global_state| { + if dev_idx >= global_state.devices.len() as c_int { + Err(CUresult::CUDA_ERROR_INVALID_DEVICE) + } else { + Ok(f(unsafe { + transmute_lifetime_mut(&mut global_state.devices[dev_idx as usize]) + })) + } + })? + } + + fn lock_current_context<F: FnOnce(&mut context::ContextData) -> R, R>( + f: F, + ) -> Result<R, CUresult> { + Self::lock_current_context_unchecked(|ctx| Ok(f(ctx.as_result_mut()?)))? + } + + fn lock_current_context_unchecked<F: FnOnce(&mut context::Context) -> R, R>( + f: F, + ) -> Result<R, CUresult> { + context::CONTEXT_STACK.with(|stack| { + stack + .borrow_mut() + .last_mut() + .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT) + .map(|ctx| GlobalState::lock(|_| f(unsafe { &mut **ctx })))? + }) + } + + fn lock_stream<T>( + stream: *mut stream::Stream, + f: impl FnOnce(&mut stream::StreamData) -> T, + ) -> Result<T, CUresult> { + if stream == ptr::null_mut() + || stream == stream::CU_STREAM_LEGACY + || stream == stream::CU_STREAM_PER_THREAD + { + Self::lock_current_context(|ctx| Ok(f(&mut ctx.default_stream)))? + } else { + Self::lock(|_| { + let stream = unsafe { &mut *stream }.as_result_mut()?; + Ok(f(stream)) + })? + } + } +} + // TODO: implement fn is_intel_gpu_driver(_: &l0::Driver) -> bool { true } -pub fn init() -> l0::Result<()> { +pub fn init() -> Result<(), CUresult> { let mut global_state = GLOBAL_STATE .lock() - .map_err(|_| l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?; + .map_err(|_| CUresult::CUDA_ERROR_UNKNOWN)?; if global_state.is_some() { return Ok(()); } l0::init()?; let drivers = l0::Driver::get()?; - let driver = match drivers.into_iter().find(is_intel_gpu_driver) { - None => return Err(l0::sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), - Some(driver) => { - device::init(&driver)?; - driver - } + let devices = match drivers.into_iter().find(is_intel_gpu_driver) { + None => return Err(CUresult::CUDA_ERROR_UNKNOWN), + Some(driver) => device::init(&driver)?, }; - *global_state = Some(GlobalState { driver }); + *global_state = Some(GlobalState { devices }); drop(global_state); Ok(()) } -unsafe fn transmute_lifetime<'a, 'b, T: ?Sized>(t: &'a T) -> &'b T { +unsafe fn transmute_lifetime_mut<'a, 'b, T: ?Sized>(t: &'a mut T) -> &'b mut T { mem::transmute(t) } diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index 35436c3..4422107 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -1,79 +1,90 @@ use std::{ - collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex, + collections::hash_map, collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, + os::raw::c_char, ptr, slice, }; -use super::{function::Function, transmute_lifetime, CUresult}; +use super::{ + device, function::Function, function::FunctionData, CUresult, GlobalState, HasLivenessCookie, + LiveCheck, +}; use ptx; -pub type Module = Mutex<ModuleData>; +pub type Module = LiveCheck<ModuleData>; + +impl HasLivenessCookie for ModuleData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0xf1313bd46505f98a; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0xbdbe3f15; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + Ok(()) + } +} pub struct ModuleData { - base: l0::Module, - arg_lens: HashMap<CString, Vec<usize>>, + pub spirv: SpirvModule, + // This should be a Vec<>, but I'm feeling lazy + pub device_binaries: HashMap<device::Index, CompiledModule>, } -pub enum ModuleCompileError<'a> { - Parse( - Vec<ptx::ast::PtxError>, - Option<ptx::ParseError<usize, ptx::Token<'a>, ptx::ast::PtxError>>, - ), - Compile(ptx::TranslateError), - L0(l0::sys::ze_result_t), - CUDA(CUresult), +pub struct SpirvModule { + pub binaries: Vec<u32>, + pub kernel_info: HashMap<String, ptx::KernelInfo>, + pub should_link_ptx_impl: Option<&'static [u8]>, + pub build_options: CString, } -impl<'a> ModuleCompileError<'a> { - pub fn get_build_log(&self) {} +pub struct CompiledModule { + pub base: l0::Module, + pub kernels: HashMap<CString, Box<Function>>, } -impl<'a> From<ptx::TranslateError> for ModuleCompileError<'a> { - fn from(err: ptx::TranslateError) -> Self { - ModuleCompileError::Compile(err) +impl<L, T, E> From<ptx::ParseError<L, T, E>> for CUresult { + fn from(_: ptx::ParseError<L, T, E>) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX } } -impl<'a> From<l0::sys::ze_result_t> for ModuleCompileError<'a> { - fn from(err: l0::sys::ze_result_t) -> Self { - ModuleCompileError::L0(err) +impl From<ptx::TranslateError> for CUresult { + fn from(_: ptx::TranslateError) -> Self { + CUresult::CUDA_ERROR_INVALID_PTX } } -impl<'a> From<CUresult> for ModuleCompileError<'a> { - fn from(err: CUresult) -> Self { - ModuleCompileError::CUDA(err) +impl SpirvModule { + pub fn new_raw<'a>(text: *const c_char) -> Result<Self, CUresult> { + let u8_text = unsafe { CStr::from_ptr(text) }; + let ptx_text = u8_text + .to_str() + .map_err(|_| CUresult::CUDA_ERROR_INVALID_PTX)?; + Self::new(ptx_text) } -} -impl ModuleData { - pub fn compile_spirv<'a>(ptx_text: &'a str) -> Result<Module, ModuleCompileError<'a>> { + pub fn new<'a>(ptx_text: &str) -> Result<Self, CUresult> { let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text); - let ast = match ast { - Err(e) => return Err(ModuleCompileError::Parse(errors, Some(e))), - Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)), - Ok(ast) => ast, - }; - let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?; + let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_text)?; + let spirv_module = ptx::to_spirv_module(ast)?; + Ok(SpirvModule { + binaries: spirv_module.assemble(), + kernel_info: spirv_module.kernel_info, + should_link_ptx_impl: spirv_module.should_link_ptx_impl, + build_options: spirv_module.build_options, + }) + } + + pub fn compile(&self, ctx: &mut l0::Context, dev: &l0::Device) -> Result<l0::Module, CUresult> { let byte_il = unsafe { - slice::from_raw_parts::<u8>( - spirv.as_ptr() as *const _, - spirv.len() * mem::size_of::<u32>(), + slice::from_raw_parts( + self.binaries.as_ptr() as *const u8, + self.binaries.len() * mem::size_of::<u32>(), ) }; - let module = super::device::with_current_exclusive(|dev| { - l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None) - }); - match module { - Ok((Ok(module), _)) => Ok(Mutex::new(Self { - base: module, - arg_lens: all_arg_lens - .into_iter() - .map(|(k, v)| (CString::new(k).unwrap(), v)) - .collect(), - })), - Ok((Err(err), _)) => Err(ModuleCompileError::from(err)), - Err(err) => Err(ModuleCompileError::from(err)), - } + let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?; + Ok(l0_module) } } @@ -85,34 +96,75 @@ pub fn get_function( if hfunc == ptr::null_mut() || hmod == ptr::null_mut() || name == ptr::null() { return Err(CUresult::CUDA_ERROR_INVALID_VALUE); } - let name = unsafe { CStr::from_ptr(name) }; - let (mut kernel, args_len) = unsafe { &*hmod } - .try_lock() - .map(|module| { - Result::<_, CUresult>::Ok(( - l0::Kernel::new_resident(unsafe { transmute_lifetime(&module.base) }, name)?, - module - .arg_lens - .get(name) - .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)? - .clone(), - )) - }) - .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)??; - kernel.set_indirect_access( - l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE - | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST - | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED, - )?; - unsafe { - *hfunc = Box::into_raw(Box::new(Function { - base: kernel, - arg_size: args_len, - })) - }; + let name = unsafe { CStr::from_ptr(name) }.to_owned(); + let function: *mut Function = GlobalState::lock_current_context(|ctx| { + let module = unsafe { &mut *hmod }.as_result_mut()?; + let device = unsafe { &mut *ctx.device }; + let compiled_module = match module.device_binaries.entry(device.index) { + hash_map::Entry::Occupied(entry) => entry.into_mut(), + hash_map::Entry::Vacant(entry) => { + let new_module = CompiledModule { + base: module.spirv.compile(&mut device.l0_context, &device.base)?, + kernels: HashMap::new(), + }; + entry.insert(new_module) + } + }; + //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) }; + let kernel = match compiled_module.kernels.entry(name) { + hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(), + hash_map::Entry::Vacant(entry) => { + let kernel_info = module + .spirv + .kernel_info + .get(unsafe { + std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) + }) + .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; + let kernel = + l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; + entry.insert(Box::new(Function::new(FunctionData { + base: kernel, + arg_size: kernel_info.arguments_sizes.clone(), + use_shared_mem: kernel_info.uses_shared_mem, + }))) + } + }; + Ok::<_, CUresult>(kernel as *mut _) + })??; + unsafe { *hfunc = function }; Ok(()) } -pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> { +pub(crate) fn load_data(pmod: *mut *mut Module, image: *const c_void) -> Result<(), CUresult> { + let spirv_data = SpirvModule::new_raw(image as *const _)?; + load_data_impl(pmod, spirv_data) +} + +pub fn load_data_impl(pmod: *mut *mut Module, spirv_data: SpirvModule) -> Result<(), CUresult> { + let module = GlobalState::lock_current_context(|ctx| { + let device = unsafe { &mut *ctx.device }; + let l0_module = spirv_data.compile(&mut device.l0_context, &device.base)?; + let mut device_binaries = HashMap::new(); + let compiled_module = CompiledModule { + base: l0_module, + kernels: HashMap::new(), + }; + device_binaries.insert(device.index, compiled_module); + let module_data = ModuleData { + spirv: spirv_data, + device_binaries, + }; + Ok::<_, CUresult>(module_data) + })??; + let module_ptr = Box::into_raw(Box::new(Module::new(module))); + unsafe { *pmod = module_ptr }; Ok(()) } + +pub(crate) fn unload(module: *mut Module) -> Result<(), CUresult> { + if module == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Module::destroy_impl(module))? +} diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs index 1844677..e212dfc 100644 --- a/notcuda/src/impl/stream.rs +++ b/notcuda/src/impl/stream.rs @@ -1,36 +1,114 @@ -use std::cell::RefCell; +use super::{ + context::{Context, ContextData}, + CUresult, GlobalState, +}; +use std::{mem, ptr}; -use device::Device; +use super::{HasLivenessCookie, LiveCheck}; -use super::device; +pub type Stream = LiveCheck<StreamData>; -pub struct Stream { - dev: *mut Device, +pub const CU_STREAM_LEGACY: *mut Stream = 1 as *mut _; +pub const CU_STREAM_PER_THREAD: *mut Stream = 2 as *mut _; + +impl HasLivenessCookie for StreamData { + #[cfg(target_pointer_width = "64")] + const COOKIE: usize = 0x512097354de18d35; + + #[cfg(target_pointer_width = "32")] + const COOKIE: usize = 0x77d5cc0b; + + const LIVENESS_FAIL: CUresult = CUresult::CUDA_ERROR_INVALID_HANDLE; + + fn try_drop(&mut self) -> Result<(), CUresult> { + if self.context != ptr::null_mut() { + let context = unsafe { &mut *self.context }; + if !context.streams.remove(&(self as *mut _)) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } + } + Ok(()) + } +} + +pub struct StreamData { + pub context: *mut ContextData, + pub queue: l0::CommandQueue, } -pub struct DefaultStream { - streams: Vec<Option<Stream>>, +impl StreamData { + pub fn new_unitialized(ctx: &mut l0::Context, dev: &l0::Device) -> Result<Self, CUresult> { + Ok(StreamData { + context: ptr::null_mut(), + queue: l0::CommandQueue::new(ctx, dev)?, + }) + } + pub fn new(ctx: &mut ContextData) -> Result<Self, CUresult> { + let l0_ctx = &mut unsafe { &mut *ctx.device }.l0_context; + let l0_dev = &unsafe { &*ctx.device }.base; + Ok(StreamData { + context: ctx as *mut _, + queue: l0::CommandQueue::new(l0_ctx, l0_dev)?, + }) + } + + pub fn command_list(&self) -> Result<l0::CommandList, l0::sys::_ze_result_t> { + let ctx = unsafe { &mut *self.context }; + let dev = unsafe { &mut *ctx.device }; + l0::CommandList::new(&mut dev.l0_context, &dev.base) + } } -impl DefaultStream { - fn new() -> Self { - DefaultStream { - streams: Vec::new(), +impl Drop for StreamData { + fn drop(&mut self) { + if self.context == ptr::null_mut() { + return; } + unsafe { (&mut *self.context).streams.remove(&(&mut *self as *mut _)) }; } } -thread_local! { - pub static DEFAULT_STREAM: RefCell<DefaultStream> = RefCell::new(DefaultStream::new()); +pub(crate) fn get_ctx(hstream: *mut Stream, pctx: *mut *mut Context) -> Result<(), CUresult> { + if pctx == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + let ctx_ptr = GlobalState::lock_stream(hstream, |stream| stream.context)?; + if ctx_ptr == ptr::null_mut() { + return Err(CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED); + } + unsafe { *pctx = Context::ptr_from_inner(ctx_ptr) }; + Ok(()) +} + +pub(crate) fn create(phstream: *mut *mut Stream, _flags: u32) -> Result<(), CUresult> { + let stream_ptr = GlobalState::lock_current_context(|ctx| { + let mut stream_box = Box::new(Stream::new(StreamData::new(ctx)?)); + let stream_ptr = stream_box.as_mut().as_option_mut().unwrap() as *mut _; + if !ctx.streams.insert(stream_ptr) { + return Err(CUresult::CUDA_ERROR_UNKNOWN); + } + mem::forget(stream_box); + Ok::<_, CUresult>(stream_ptr) + })??; + unsafe { *phstream = Stream::ptr_from_inner(stream_ptr) }; + Ok(()) +} + +pub(crate) fn destroy_v2(pstream: *mut Stream) -> Result<(), CUresult> { + if pstream == ptr::null_mut() || pstream == CU_STREAM_LEGACY || pstream == CU_STREAM_PER_THREAD + { + return Err(CUresult::CUDA_ERROR_INVALID_VALUE); + } + GlobalState::lock(|_| Stream::destroy_impl(pstream))? } #[cfg(test)] -mod tests { +mod test { use crate::cuda::CUstream; use super::super::test::CudaDriverFns; use super::super::CUresult; - use std::ptr; + use std::{ptr, thread}; const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; @@ -65,5 +143,100 @@ mod tests { CUresult::CUDA_SUCCESS ); assert_eq!(ctx2, stream_ctx2); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx1), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_context_destroyed); + + fn stream_context_destroyed<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + let mut stream_ctx2 = ptr::null_mut(); + // When a context gets destroyed, its streams are also destroyed + let cuda_result = T::cuStreamGetCtx(stream, &mut stream_ctx2); + assert!( + cuda_result == CUresult::CUDA_ERROR_INVALID_HANDLE + || cuda_result == CUresult::CUDA_ERROR_INVALID_CONTEXT + || cuda_result == CUresult::CUDA_ERROR_CONTEXT_IS_DESTROYED + ); + assert_eq!( + T::cuStreamDestroy_v2(stream), + CUresult::CUDA_ERROR_INVALID_HANDLE + ); + // Check if creating another context is possible + let mut ctx2 = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx2, 0, 0), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx2), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(stream_moves_context_to_another_thread); + + fn stream_moves_context_to_another_thread<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + let mut stream_ctx1 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream, &mut stream_ctx1), + CUresult::CUDA_SUCCESS + ); + assert_eq!(stream_ctx1, ctx); + let stream_ptr = stream as usize; + let stream_ctx_on_thread = thread::spawn(move || { + let mut stream_ctx2 = ptr::null_mut(); + assert_eq!( + T::cuStreamGetCtx(stream_ptr as *mut _, &mut stream_ctx2), + CUresult::CUDA_SUCCESS + ); + stream_ctx2 as usize + }) + .join() + .unwrap(); + assert_eq!(stream_ctx1, stream_ctx_on_thread as *mut _); + // Cleanup + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(can_destroy_stream); + + fn can_destroy_stream<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + let mut stream = ptr::null_mut(); + assert_eq!(T::cuStreamCreate(&mut stream, 0), CUresult::CUDA_SUCCESS); + assert_eq!(T::cuStreamDestroy_v2(stream), CUresult::CUDA_SUCCESS); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); + } + + cuda_driver_test!(cant_destroy_default_stream); + + fn cant_destroy_default_stream<T: CudaDriverFns>() { + assert_eq!(T::cuInit(0), CUresult::CUDA_SUCCESS); + let mut ctx = ptr::null_mut(); + assert_eq!(T::cuCtxCreate_v2(&mut ctx, 0, 0), CUresult::CUDA_SUCCESS); + assert_ne!( + T::cuStreamDestroy_v2(super::CU_STREAM_LEGACY as *mut _), + CUresult::CUDA_SUCCESS + ); + // Cleanup + assert_eq!(T::cuCtxDestroy_v2(ctx), CUresult::CUDA_SUCCESS); } } diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs index dbd2eff..b6ed926 100644 --- a/notcuda/src/impl/test.rs +++ b/notcuda/src/impl/test.rs @@ -1,8 +1,12 @@ #![allow(non_snake_case)] -use crate::{cuda::CUstream, r#impl as notcuda}; -use crate::r#impl::CUresult; -use crate::{cuda::CUuuid, r#impl::Encuda}; +use crate::cuda as notcuda; +use crate::cuda::CUstream; +use crate::cuda::CUuuid; +use crate::{ + cuda::{CUdevice, CUdeviceptr}, + r#impl::CUresult, +}; use ::std::{ ffi::c_void, os::raw::{c_int, c_uint}, @@ -37,48 +41,63 @@ pub trait CudaDriverFns { fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult; fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult; fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult; + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult; + fn cuMemFree_v2(mem: *mut c_void) -> CUresult; + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult; } pub struct NotCuda(); impl CudaDriverFns for NotCuda { fn cuInit(_flags: c_uint) -> CUresult { - crate::cuda::cuInit(_flags as _) + notcuda::cuInit(_flags as _) } fn cuCtxCreate_v2(pctx: *mut *mut c_void, flags: c_uint, dev: c_int) -> CUresult { - notcuda::context::create_v2(pctx as *mut _, flags, notcuda::device::Index(dev)).encuda() + notcuda::cuCtxCreate_v2(pctx as *mut _, flags, CUdevice(dev)) } fn cuCtxDestroy_v2(ctx: *mut c_void) -> CUresult { - notcuda::context::destroy_v2(ctx as *mut _) + notcuda::cuCtxDestroy_v2(ctx as *mut _) } fn cuCtxPopCurrent_v2(pctx: *mut *mut c_void) -> CUresult { - notcuda::context::pop_current_v2(pctx as *mut _) + notcuda::cuCtxPopCurrent_v2(pctx as *mut _) } fn cuCtxGetApiVersion(ctx: *mut c_void, version: *mut c_uint) -> CUresult { - notcuda::context::get_api_version(ctx as *mut _, version) + notcuda::cuCtxGetApiVersion(ctx as *mut _, version) } fn cuCtxGetCurrent(pctx: *mut *mut c_void) -> CUresult { - notcuda::context::get_current(pctx as *mut _).encuda() + notcuda::cuCtxGetCurrent(pctx as *mut _) } fn cuMemAlloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> CUresult { - notcuda::memory::alloc_v2(dptr as *mut _, bytesize) + notcuda::cuMemAlloc_v2(dptr as *mut _, bytesize) } fn cuDeviceGetUuid(uuid: *mut CUuuid, dev: c_int) -> CUresult { - notcuda::device::get_uuid(uuid, notcuda::device::Index(dev)).encuda() + notcuda::cuDeviceGetUuid(uuid, CUdevice(dev)) } fn cuDevicePrimaryCtxGetState(dev: c_int, flags: *mut c_uint, active: *mut c_int) -> CUresult { - notcuda::device::primary_ctx_get_state(notcuda::device::Index(dev), flags, active).encuda() + notcuda::cuDevicePrimaryCtxGetState(CUdevice(dev), flags, active) } fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { - crate::cuda::cuStreamGetCtx(hStream, pctx as _) + notcuda::cuStreamGetCtx(hStream, pctx as _) + } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + notcuda::cuStreamCreate(stream, flags) + } + + fn cuMemFree_v2(dptr: *mut c_void) -> CUresult { + notcuda::cuMemFree_v2(CUdeviceptr(dptr as _)) + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + notcuda::cuStreamDestroy_v2(stream) } } @@ -123,4 +142,16 @@ impl CudaDriverFns for Cuda { fn cuStreamGetCtx(hStream: CUstream, pctx: *mut *mut c_void) -> CUresult { unsafe { CUresult(cuda::cuStreamGetCtx(hStream as _, pctx as _) as c_uint) } } + + fn cuStreamCreate(stream: *mut CUstream, flags: c_uint) -> CUresult { + unsafe { CUresult(cuda::cuStreamCreate(stream as _, flags as _) as c_uint) } + } + + fn cuMemFree_v2(mem: *mut c_void) -> CUresult { + unsafe { CUresult(cuda::cuMemFree_v2(mem as _) as c_uint) } + } + + fn cuStreamDestroy_v2(stream: CUstream) -> CUresult { + unsafe { CUresult(cuda::cuStreamDestroy_v2(stream as _) as c_uint) } + } } diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1aac8ab..591428f 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -34,8 +34,9 @@ pub use crate::ptx::ModuleParser; pub use lalrpop_util::lexer::Token; pub use lalrpop_util::ParseError; pub use rspirv::dr::Error as SpirvError; -pub use translate::TranslateError as TranslateError; -pub use translate::to_spirv; +pub use translate::to_spirv_module; +pub use translate::KernelInfo; +pub use translate::TranslateError; pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> { x.into_iter().filter_map(|x| x).collect() diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs index 0339141..0785f3e 100644 --- a/ptx/src/test/mod.rs +++ b/ptx/src/test/mod.rs @@ -12,7 +12,7 @@ fn parse_and_assert(s: &str) { fn compile_and_assert(s: &str) -> Result<(), TranslateError> { let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap(); - crate::to_spirv(ast)?; + crate::to_spirv_module(ast)?; Ok(()) } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c0e15f2..3d0f476 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -450,6 +450,11 @@ pub struct Module { pub should_link_ptx_impl: Option<&'static [u8]>,
pub build_options: CString,
}
+impl Module {
+ pub fn assemble(&self) -> Vec<u32> {
+ self.spirv.assemble()
+ }
+}
pub struct KernelInfo {
pub arguments_sizes: Vec<usize>,
@@ -1046,8 +1051,12 @@ fn emit_function_header<'a>( kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let MethodName::Kernel(name) = func_decl.name {
- let args_lens = func_decl
- .input
+ let input_args = if !func_decl.uses_shared_mem {
+ func_decl.input.as_slice()
+ } else {
+ &func_decl.input[0..func_decl.input.len() - 1]
+ };
+ let args_lens = input_args
.iter()
.map(|param| param.v_type.size_of())
.collect();
@@ -1135,21 +1144,6 @@ fn emit_function_header<'a>( Ok(())
}
-pub fn to_spirv<'a>(
- ast: ast::Module<'a>,
-) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
- let module = to_spirv_module(ast)?;
- Ok((
- module.should_link_ptx_impl,
- module.spirv.assemble(),
- module
- .kernel_info
- .into_iter()
- .map(|(k, v)| (k, v.arguments_sizes))
- .collect(),
- ))
-}
-
fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::GenericPointer);
builder.capability(spirv::Capability::Linkage);
|