aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2024-03-28 16:53:52 +0100
committer Andrzej Janik <[email protected]> 2024-03-28 16:53:52 +0100
commit 067c923408e962260bc4cfbf9374b3f75ca939ff (patch)
tree b859e69871df173253285a852d09c9999d483752
parent d41fbd50fff117b4ce0a7815aa158e8b356dbc00 (diff)
downloadZLUDA-blender_42.tar.gz
ZLUDA-blender_42.zip
Clean up, add cuCtxSetFlags (blender_42)
-rw-r--r-- zluda/src/cuda.rs 5
-rw-r--r-- zluda/src/impl/context.rs 14
-rw-r--r-- zluda/src/impl/device.rs 8
-rw-r--r-- zluda/tests/primary_context.rs 20
4 files changed, 41 insertions, 6 deletions
diff --git a/zluda/src/cuda.rs b/zluda/src/cuda.rs
index 898d732..1d054c3 100644
--- a/zluda/src/cuda.rs
+++ b/zluda/src/cuda.rs
@@ -69,6 +69,7 @@ cuda_function_declarations!(
cuCtxGetDevice,
cuCtxGetLimit,
cuCtxSetLimit,
+ cuCtxSetFlags,
cuCtxGetStreamPriorityRange,
cuCtxSynchronize,
cuCtxSetCacheConfig,
@@ -485,6 +486,10 @@ mod definitions {
context::set_limit(limit, value)
}
+ pub(crate) unsafe fn cuCtxSetFlags(flags: u32) -> Result<(), CUresult> {
+ context::set_flags(flags) // forwards `flags` unchanged to the context module's implementation
+ }
+
pub(crate) unsafe fn cuCtxGetStreamPriorityRange(
leastPriority: *mut ::std::os::raw::c_int,
greatestPriority: *mut ::std::os::raw::c_int,
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index b16467a..d1b3e7b 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -222,6 +222,20 @@ pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> Result<(), CUresult>
Ok(())
}
+pub(crate) fn set_flags(flags: u32) -> Result<(), CUresult> {
+ with_current(|ctx| match ctx.variant {
+ ContextVariant::NonPrimary(ref context) => {
+ context
+ .flags
+ .store(flags, std::sync::atomic::Ordering::SeqCst); // overwrites the stored flags with the new value
+ Ok(())
+ }
+ // Deliberate no-op: real CUDA reports success here but leaves the primary
+ // context's flags unchanged; tests/primary_context.rs pins this behavior.
+ ContextVariant::Primary(_) => Ok(()),
+ })? // `?` propagates an error from `with_current` itself; the closure's Result becomes the return value
+}
+
pub(crate) unsafe fn get_api_version(ctx: *mut Context, version: *mut u32) -> Result<(), CUresult> {
if ctx == ptr::null_mut() {
return Err(CUresult::CUDA_ERROR_INVALID_CONTEXT);
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index caace1d..c7e8190 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -1,6 +1,6 @@
use super::context::{ContextInnerMutable, ContextVariant, PrimaryContextData};
use super::{
- context, LiveCheck, ZludaObject, GLOBAL_STATE
+ context, LiveCheck, GLOBAL_STATE
};
use crate::r#impl::context::ContextData;
use crate::{r#impl::IntoCuda, hip_call_cuda};
@@ -12,11 +12,7 @@ use paste::paste;
use std::{
mem,
os::raw::{c_char, c_uint},
- ptr,
- sync::{
- atomic::AtomicU32,
- Mutex,
- }, ops::AddAssign, ffi::CString,
+ ptr,ffi::CString,
};
const ZLUDA_SUFFIX: &'static [u8] = b" [ZLUDA]\0";
diff --git a/zluda/tests/primary_context.rs b/zluda/tests/primary_context.rs
index 4f726a2..f72c7b1 100644
--- a/zluda/tests/primary_context.rs
+++ b/zluda/tests/primary_context.rs
@@ -18,17 +18,37 @@ unsafe fn primary_context<T: CudaDriverFns>(cuda: T) {
cuda.cuDevicePrimaryCtxSetFlags_v2(CUdevice_v1(0), 1),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(
+ cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!((1, 0), (flags, active));
let mut primary_ctx = ptr::null_mut();
assert_eq!(
cuda.cuDevicePrimaryCtxRetain(&mut primary_ctx, CUdevice_v1(0)),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(
+ cuda.cuCtxPushCurrent_v2(primary_ctx),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(cuda.cuCtxSetFlags(2), CUresult::CUDA_SUCCESS);
+ assert_eq!(
+ cuda.cuCtxSetCurrent(ptr::null_mut()),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!(
+ cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),
+ CUresult::CUDA_SUCCESS
+ );
+ assert_eq!((1, 1), (flags, active));
assert_ne!(primary_ctx, ptr::null_mut());
let mut active_ctx = ptr::null_mut();
assert_eq!(
cuda.cuCtxGetCurrent(&mut active_ctx),
CUresult::CUDA_SUCCESS
);
+ assert_eq!(active_ctx, ptr::null_mut());
assert_ne!(primary_ctx, active_ctx);
assert_eq!(
cuda.cuDevicePrimaryCtxGetState(CUdevice_v1(0), &mut flags, &mut active),