diff options
author | Andrzej Janik <[email protected]> | 2024-03-28 16:53:52 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-03-28 16:53:52 +0100 |
commit | 067c923408e962260bc4cfbf9374b3f75ca939ff (patch) | |
tree | b859e69871df173253285a852d09c9999d483752 | |
parent | d41fbd50fff117b4ce0a7815aa158e8b356dbc00 (diff) | |
download | ZLUDA-blender_42.tar.gz ZLUDA-blender_42.zip |
Clean up, add cuCtxSetFlagsblender_42
-rw-r--r-- | zluda/src/cuda.rs | 5 | ||||
-rw-r--r-- | zluda/src/impl/context.rs | 14 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 8 | ||||
-rw-r--r-- | zluda/tests/primary_context.rs | 20 |
4 files changed, 41 insertions, 6 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs index 898d732..1d054c3 100644 --- a/zluda/src/cuda.rs +++ b/zluda/src/cuda.rs @@ -69,6 +69,7 @@ cuda_function_declarations!( cuCtxGetDevice,
cuCtxGetLimit,
cuCtxSetLimit,
+ cuCtxSetFlags,
cuCtxGetStreamPriorityRange,
cuCtxSynchronize,
cuCtxSetCacheConfig,
@@ -485,6 +486,10 @@ mod definitions { context::set_limit(limit, value)
}
+ pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
+ context::set_flags(flags)
+ }
+
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
leastPriority: *mut ::std::os::raw::c_int,
greatestPriority: *mut ::std::os::raw::c_int,
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs index b16467a..d1b3e7b 100644 --- a/zluda/src/impl/context.rs +++ b/zluda/src/impl/context.rs @@ -222,6 +222,20 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult> Ok(()) } +pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> { + with_current(|ctx| match ctx.variant { + ContextVariant::NonPrimary(ref context) => { + context + .flags + .store(flags, std::sync::atomic::Ordering::SeqCst); + Ok(()) + } + // This looks stupid, but this is an actual CUDA behavior, + // see primary_context.rs test + ContextVariant::Primary(_) => Ok(()), + })? +} + pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> { if ctx == ptr::null_mut() { return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT); diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index caace1d..c7e8190 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -1,6 +1,6 @@ use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData}; use super::{ - context, LiveCheck, ZludaObject, GLOBAL_STATE + context, LiveCheck, GLOBAL_STATE }; use crate::r#impl::context::ContextData; use crate::{r#impl::IntoCuda, hip_call_cuda}; @@ -12,11 +12,7 @@ use paste::paste; use std::{ mem, os::raw::{c_char, c_uint}, - ptr, - sync::{ - atomic::AtomicU32, - Mutex, - }, ops::AddAssign, ffi::CString, + ptr,ffi::CString, }; const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0"; diff --git a/zluda/tests/primary_context.rs b/zluda/tests/primary_context.rs index 4f726a2..f72c7b1 100644 --- a/zluda/tests/primary_context.rs +++ b/zluda/tests/primary_context.rs @@ -18,17 +18,37 @@ unsafe fn primary_context<T: CudaDriverFns>(cuda: T) { cuda.cuDevicePrimaryCtxSetFlags_v2(CUdevice_v1(0), 1),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(
+ cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!((1, 0), (flags, active));
let mut primary_ctx = ptr::null_mut();
assert_eq!(
cuda.cuDevicePrimaryCtxRetain(&mut primary_ctx, CUdevice_v1(0)),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(
+ cuda.cuCtxPushCurrent_v2(primary_ctx),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(cuda.cuCtxSetFlags(2), CUresult::CUDA_SUCCESS);
+ assert_eq!(
+ cuda.cuCtxSetCurrent(ptr::null_mut()),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(
+ cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!((1, 1), (flags, active));
assert_ne!(primary_ctx, ptr::null_mut());
let mut active_ctx = ptr::null_mut();
assert_eq!(
cuda.cuCtxGetCurrent(&mut active_ctx),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(active_ctx, ptr::null_mut());
assert_ne!(primary_ctx, active_ctx);
assert_eq!(
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
|