diff options
Diffstat (limited to 'notcuda/src/impl/export_table.rs')
-rw-r--r-- | notcuda/src/impl/export_table.rs | 83 |
1 files changed, 31 insertions, 52 deletions
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs index 562af37..ae9f6e3 100644 --- a/notcuda/src/impl/export_table.rs +++ b/notcuda/src/impl/export_table.rs @@ -4,7 +4,7 @@ use crate::{ cuda_impl,
};
-use super::{context, device, module, Decuda, Encuda};
+use super::{context, context::ContextData, module, Decuda, Encuda, GlobalState};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
@@ -110,17 +110,8 @@ static CUDART_INTERFACE_VTABLE: [VTableEntry; CUDART_INTERFACE_LENGTH] = [ VTableEntry { ptr: ptr::null() },
];
-unsafe extern "C" fn cudart_interface_fn1(pctx: *mut CUcontext, dev: CUdevice) -> CUresult {
- cudart_interface_fn1_impl(pctx.decuda(), dev.decuda()).encuda()
-}
-
-fn cudart_interface_fn1_impl(
- pctx: *mut *mut context::Context,
- dev: device::Index,
-) -> Result<(), CUresult> {
- let ctx_ptr = device::with_exclusive(dev, |d| &mut d.primary_context as *mut context::Context)?;
- unsafe { *pctx = ctx_ptr };
- Ok(())
+unsafe extern "C" fn cudart_interface_fn1(_pctx: *mut CUcontext, _dev: CUdevice) -> CUresult {
+ super::unimplemented()
}
/*
@@ -200,7 +191,7 @@ unsafe extern "C" fn get_module_from_cubin( ptr1: *mut c_void,
ptr2: *mut c_void,
) -> CUresult {
- // Not sure what those twoparameters are actually used for,
+ // Not sure what those two parameters are actually used for,
// they are somehow involved in __cudaRegisterHostVar
if ptr1 != ptr::null_mut() || ptr2 != ptr::null_mut() {
return CUresult::CUDA_ERROR_NOT_SUPPORTED;
@@ -234,10 +225,13 @@ unsafe extern "C" fn get_module_from_cubin( },
Err(_) => continue,
};
- let module = module::ModuleData::compile_spirv(kernel_text_string);
+ let module = module::SpirvModule::new(kernel_text_string);
match module {
Ok(module) => {
- *result = Box::into_raw(Box::new(module));
+ match module::load_data_impl(result, module) {
+ Ok(()) => {}
+ Err(err) => return err,
+ }
return CUresult::CUDA_SUCCESS;
}
Err(_) => continue,
@@ -309,7 +303,7 @@ unsafe extern "C" fn context_local_storage_ctor( }
fn context_local_storage_ctor_impl(
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
mgr: *mut cuda_impl::rt::ContextStateManager,
ctx_state: *mut cuda_impl::rt::ContextState,
dtor_cb: Option<
@@ -320,26 +314,11 @@ fn context_local_storage_ctor_impl( ),
>,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mut mutable| {
- mutable.cuda_manager = mgr;
- mutable.cuda_state = ctx_state;
- mutable.cuda_dtor_cb = dtor_cb;
- })
- })?;
- Ok(())
+ lock_context(cu_ctx, |ctx: &mut ContextData| {
+ ctx.cuda_manager = mgr;
+ ctx.cuda_state = ctx_state;
+ ctx.cuda_dtor_cb = dtor_cb;
+ })
}
// some kind of dtor
@@ -357,24 +336,10 @@ unsafe extern "C" fn context_local_storage_get_state( fn context_local_storage_get_state_impl(
ctx_state: *mut *mut cuda_impl::rt::ContextState,
- mut cu_ctx: *mut context::Context,
+ cu_ctx: *mut context::Context,
_: *mut cuda_impl::rt::ContextStateManager,
) -> Result<(), CUresult> {
- if cu_ctx == ptr::null_mut() {
- context::get_current(&mut cu_ctx)?;
- }
- if cu_ctx == ptr::null_mut() {
- return Err(CUresult::CUDA_ERROR_INVALID_VALUE);
- }
- let cuda_state = unsafe { &*cu_ctx }
- .as_ref()
- .ok_or(CUresult::CUDA_ERROR_INVALID_CONTEXT)
- .and_then(|ctx| {
- ctx.mutable
- .try_lock()
- .map_err(|_| CUresult::CUDA_ERROR_ILLEGAL_STATE)
- .map(|mutable| mutable.cuda_state)
- })?;
+ let cuda_state = lock_context(cu_ctx, |ctx: &mut ContextData| ctx.cuda_state)?;
if cuda_state == ptr::null_mut() {
Err(CUresult::CUDA_ERROR_INVALID_VALUE)
} else {
@@ -382,3 +347,17 @@ fn context_local_storage_get_state_impl( Ok(())
}
}
+
+fn lock_context<T>(
+ cu_ctx: *mut context::Context,
+ fn_impl: impl FnOnce(&mut ContextData) -> T,
+) -> Result<T, CUresult> {
+ if cu_ctx == ptr::null_mut() {
+ GlobalState::lock_current_context(fn_impl)
+ } else {
+ GlobalState::lock(|_| {
+ let ctx = unsafe { &mut *cu_ctx }.as_result_mut()?;
+ Ok(fn_impl(ctx))
+ })?
+ }
+}
|