aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/context.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda/src/impl/context.rs')
-rw-r--r--zluda/src/impl/context.rs99
1 files changed, 84 insertions, 15 deletions
diff --git a/zluda/src/impl/context.rs b/zluda/src/impl/context.rs
index fffceb8..973febc 100644
--- a/zluda/src/impl/context.rs
+++ b/zluda/src/impl/context.rs
@@ -1,24 +1,93 @@
-use std::ptr;
+use super::{driver, FromCuda, ZludaObject};
+use cuda_types::*;
+use hip_runtime_sys::*;
+use rustc_hash::FxHashSet;
+use std::{cell::RefCell, ptr, sync::Mutex};
-use crate::cuda::CUlimit;
-use crate::cuda::CUresult;
+thread_local! {
+ pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
+}
+
+pub(crate) struct Context {
+ pub(crate) device: hipDevice_t,
+ pub(crate) mutable: Mutex<OwnedByContext>,
+}
+
+pub(crate) struct OwnedByContext {
+ pub(crate) ref_count: usize, // only used by primary context
+ pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
+ pub(crate) _streams: FxHashSet<hipStream_t>,
+ pub(crate) _modules: FxHashSet<CUmodule>,
+}
-pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: CUlimit) -> CUresult {
- if pvalue == ptr::null_mut() {
- return CUresult::CUDA_ERROR_INVALID_VALUE;
+impl ZludaObject for Context {
+ const COOKIE: usize = 0x5f867c6d9cb73315;
+
+ type CudaHandle = CUcontext;
+
+ fn drop_checked(&mut self) -> CUresult {
+ Ok(())
}
- if limit == CUlimit::CU_LIMIT_STACK_SIZE {
- *pvalue = 512; // GTX 1060 reports 1024
- CUresult::CUDA_SUCCESS
- } else {
- CUresult::CUDA_ERROR_NOT_SUPPORTED
+}
+
+pub(crate) fn new(device: hipDevice_t) -> Context {
+ Context {
+ device,
+ mutable: Mutex::new(OwnedByContext {
+ ref_count: 0,
+ _memory: FxHashSet::default(),
+ _streams: FxHashSet::default(),
+ _modules: FxHashSet::default(),
+ }),
}
}
-pub(crate) fn set_limit(limit: CUlimit, value: usize) -> CUresult {
- if limit == CUlimit::CU_LIMIT_STACK_SIZE {
- CUresult::CUDA_SUCCESS
+pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
+ unsafe { hipDeviceGetLimit(pvalue, limit) }
+}
+
+pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
+ unsafe { hipDeviceSetLimit(limit, value) }
+}
+
+pub(crate) fn synchronize() -> hipError_t {
+ unsafe { hipDeviceSynchronize() }
+}
+
+pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
+ let dev = driver::device(hip_dev)?;
+ Ok(dev.primary_context())
+}
+
+pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
+ let new_device = if raw_ctx.0 == ptr::null_mut() {
+ CONTEXT_STACK.with(|stack| {
+ let mut stack = stack.borrow_mut();
+ if let Some((_, old_device)) = stack.pop() {
+ if let Some((_, new_device)) = stack.last() {
+ if old_device != *new_device {
+ return Some(*new_device);
+ }
+ }
+ }
+ None
+ })
} else {
- CUresult::CUDA_ERROR_NOT_SUPPORTED
+ let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
+ let device = ctx.device;
+ CONTEXT_STACK.with(move |stack| {
+ let mut stack = stack.borrow_mut();
+ let last_device = stack.last().map(|(_, dev)| *dev);
+ stack.push((raw_ctx, device));
+ match last_device {
+ None => Some(device),
+ Some(last_device) if last_device != device => Some(device),
+ _ => None,
+ }
+ })
+ };
+ if let Some(dev) = new_device {
+ unsafe { hipSetDevice(dev)? };
}
+ Ok(())
}