diff options
-rw-r--r-- | notcuda/src/cuda.rs | 2 | ||||
-rw-r--r-- | notcuda/src/impl/device.rs | 5 | ||||
-rw-r--r-- | notcuda/src/impl/export_table.rs | 15 | ||||
-rw-r--r-- | notcuda/src/impl/module.rs | 8 | ||||
-rw-r--r-- | ptx/src/translate.rs | 4 |
5 files changed, 26 insertions, 8 deletions
diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs index 335da4a..a528981 100644 --- a/notcuda/src/cuda.rs +++ b/notcuda/src/cuda.rs @@ -2281,7 +2281,7 @@ pub extern "C" fn cuDevicePrimaryCtxRelease(dev: CUdevice) -> CUresult { #[cfg_attr(not(test), no_mangle)] pub extern "C" fn cuDevicePrimaryCtxRelease_v2(dev: CUdevice) -> CUresult { - r#impl::unimplemented() + r#impl::device::primary_ctx_release_v2(dev.decuda()) } #[cfg_attr(not(test), no_mangle)] diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs index b8d263d..5a399dc 100644 --- a/notcuda/src/impl/device.rs +++ b/notcuda/src/impl/device.rs @@ -345,6 +345,11 @@ pub fn primary_ctx_retain( Ok(()) } +// TODO: allow for retain/reset/release of primary context +pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult { + CUresult::CUDA_SUCCESS +} + #[cfg(test)] mod test { use super::super::test::CudaDriverFns; diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index ae9f6e3..87d7f40 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -4,7 +4,7 @@ use crate::{ cuda_impl,
};
-use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
+use super::{context, context::ContextData, device, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,8 +110,17 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [ VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
- super::unimplemented()
+unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
+ cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
+}
+
+fn cudart_interface_fn1_impl(
+ pctx: *mut *mut context::Context,
+ dev: device::Index,
+) -> Result<(), CUresult> {
+ let ctx_ptr = GlobalState::lock_device(dev, |d| &mut d.primary_context as *mut _)?;
+ unsafe { *pctx = ctx_ptr };
+ Ok(())
}
/*
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index 4422107..e19d8de 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -110,7 +110,6 @@ pub fn get_function( entry.insert(new_module) } }; - //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) }; let kernel = match compiled_module.kernels.entry(name) { hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(), hash_map::Entry::Vacant(entry) => { @@ -121,8 +120,13 @@ pub fn get_function( std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes()) }) .ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?; - let kernel = + let mut kernel = l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?; + kernel.set_indirect_access( + l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE + | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST + | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED + )?; entry.insert(Box::new(Function::new(FunctionData { base: kernel, arg_size: kernel_info.arguments_sizes.clone(), diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3d0f476..f0a3187 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,6 +1,6 @@ use crate::ast;
use half::f16;
-use rspirv::{binary::Disassemble, dr};
+use rspirv::dr;
use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
@@ -6662,7 +6662,7 @@ impl ast::ScalarType { ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float,
+ ast::ScalarType::F16x2 => ScalarKind::Float2,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}
|