aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-12 21:08:28 +0100
committerAndrzej Janik <[email protected]>2020-11-12 22:47:14 +0100
commita6765baa3a91b80a7724e05973e2de6746c958d7 (patch)
tree63c879e7a343b573405e3b043f43eeb7516f677e
parenta2e77fe961fb543370261e844d6cd79e0269e877 (diff)
downloadZLUDA-a6765baa3a91b80a7724e05973e2de6746c958d7.tar.gz
ZLUDA-a6765baa3a91b80a7724e05973e2de6746c958d7.zip
Add back erroneously removed functionality
-rw-r--r--notcuda/src/cuda.rs2
-rw-r--r--notcuda/src/impl/device.rs5
-rw-r--r--notcuda/src/impl/export_table.rs15
-rw-r--r--notcuda/src/impl/module.rs8
-rw-r--r--ptx/src/translate.rs4
5 files changed, 26 insertions, 8 deletions
diff --git a/notcuda/src/cuda.rs b/notcuda/src/cuda.rs
index 335da4a..a528981 100644
--- a/notcuda/src/cuda.rs
+++ b/notcuda/src/cuda.rs
@@ -2281,7 +2281,7 @@ pub extern "C" fn cuDevicePrimaryCtxRelease(dev: CUdevice) -> CUresult {
#[cfg_attr(not(test), no_mangle)]
pub extern "C" fn cuDevicePrimaryCtxRelease_v2(dev: CUdevice) -> CUresult {
- r#impl::unimplemented()
+ r#impl::device::primary_ctx_release_v2(dev.decuda())
}
#[cfg_attr(not(test), no_mangle)]
diff --git a/notcuda/src/impl/device.rs b/notcuda/src/impl/device.rs
index b8d263d..5a399dc 100644
--- a/notcuda/src/impl/device.rs
+++ b/notcuda/src/impl/device.rs
@@ -345,6 +345,11 @@ pub fn primary_ctx_retain(
Ok(())
}
+// TODO: allow for retain/reset/release of primary context
+pub(crate) fn primary_ctx_release_v2(_dev_idx: Index) -> CUresult {
+ CUresult::CUDA_SUCCESS
+}
+
#[cfg(test)]
mod test {
use super::super::test::CudaDriverFns;
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs
index ae9f6e3..87d7f40 100644
--- a/notcuda/src/impl/export_table.rs
+++ b/notcuda/src/impl/export_table.rs
@@ -4,7 +4,7 @@ use crate::{
cuda_impl,
};
-use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
+use super::{context, context::ContextData, device, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,8 +110,17 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [
VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
- super::unimplemented()
+unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
+ cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
+}
+
+fn cudart_interface_fn1_impl(
+ pctx: *mut *mut context::Context,
+ dev: device::Index,
+) -> Result<(), CUresult> {
+ let ctx_ptr = GlobalState::lock_device(dev, |d| &mut d.primary_context as *mut _)?;
+ unsafe { *pctx = ctx_ptr };
+ Ok(())
}
/*
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index 4422107..e19d8de 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -110,7 +110,6 @@ pub fn get_function(
entry.insert(new_module)
}
};
- //let compiled_module = unsafe { transmute_lifetime_mut(compiled_module) };
let kernel = match compiled_module.kernels.entry(name) {
hash_map::Entry::Occupied(entry) => entry.into_mut().as_mut(),
hash_map::Entry::Vacant(entry) => {
@@ -121,8 +120,13 @@ pub fn get_function(
std::str::from_utf8_unchecked(entry.key().as_c_str().to_bytes())
})
.ok_or(CUresult::CUDA_ERROR_NOT_FOUND)?;
- let kernel =
+ let mut kernel =
l0::Kernel::new_resident(&compiled_module.base, entry.key().as_c_str())?;
+ kernel.set_indirect_access(
+ l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_HOST
+ | l0::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_SHARED
+ )?;
entry.insert(Box::new(Function::new(FunctionData {
base: kernel,
arg_size: kernel_info.arguments_sizes.clone(),
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 3d0f476..f0a3187 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,6 +1,6 @@
use crate::ast;
use half::f16;
-use rspirv::{binary::Disassemble, dr};
+use rspirv::dr;
use std::{borrow::Cow, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
@@ -6662,7 +6662,7 @@ impl ast::ScalarType {
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float,
+ ast::ScalarType::F16x2 => ScalarKind::Float2,
ast::ScalarType::Pred => ScalarKind::Pred,
}
}