1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
|
use super::{driver, FromCuda, ZludaObject};
use cuda_types::*;
use hip_runtime_sys::*;
use rustc_hash::FxHashSet;
use std::{cell::RefCell, ptr, sync::Mutex};
// Per-thread stack of bound contexts. Each entry pairs a `CUcontext` handle
// with that context's HIP device. `set_current` treats the last element as the
// calling thread's current context and calls `hipSetDevice` when the device of
// the top entry changes. Being `thread_local!`, every thread has its own stack.
thread_local! {
pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
}
/// Backing object for a CUDA `CUcontext` handle (its `ZludaObject::CudaHandle`).
pub(crate) struct Context {
/// HIP device this context is bound to; `set_current` passes it to
/// `hipSetDevice` when selecting the device for the calling thread.
pub(crate) device: hipDevice_t,
/// Per-context state that can change after creation, guarded by a mutex.
pub(crate) mutable: Mutex<OwnedByContext>,
}
/// Mutable state of a [`Context`], stored behind `Context::mutable`.
pub(crate) struct OwnedByContext {
/// Reference count, initialized to 0 by `new`.
pub(crate) ref_count: usize, // only used by primary context
/// Presumably the device allocations owned by this context.
/// NOTE(review): only initialized (empty) in this module; confirm whether it
/// is populated elsewhere, since the leading underscore suggests it is unused.
pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
/// Presumably the HIP streams owned by this context.
/// NOTE(review): only initialized (empty) in this module — confirm usage.
pub(crate) _streams: FxHashSet<hipStream_t>,
/// Presumably the modules loaded into this context.
/// NOTE(review): only initialized (empty) in this module — confirm usage.
pub(crate) _modules: FxHashSet<CUmodule>,
}
/// Registers [`Context`] as the object behind the CUDA `CUcontext` handle type.
impl ZludaObject for Context {
// NOTE(review): magic value distinguishing `Context` objects; presumably
// checked by the `ZludaObject` machinery when validating handles — confirm.
const COOKIE: usize = 0x5f867c6d9cb73315;
/// CUDA-side handle type that refers to a `Context`.
type CudaHandle = CUcontext;
/// Presumably invoked by the `ZludaObject` machinery before the object is
/// released; always reports success.
/// NOTE(review): does not look at `ref_count` or release anything tracked in
/// `OwnedByContext` — confirm this is intentional.
fn drop_checked(&mut self) -> CUresult {
Ok(())
}
}
/// Creates a [`Context`] bound to `device`.
///
/// The mutable state starts with a zero reference count and empty
/// memory, stream and module sets.
pub(crate) fn new(device: hipDevice_t) -> Context {
    let owned = OwnedByContext {
        _modules: FxHashSet::default(),
        _streams: FxHashSet::default(),
        _memory: FxHashSet::default(),
        ref_count: 0,
    };
    Context {
        mutable: Mutex::new(owned),
        device,
    }
}
/// Queries a device resource limit by forwarding to `hipDeviceGetLimit`,
/// returning the HIP status unchanged.
///
/// # Safety
///
/// `pvalue` is handed directly to `hipDeviceGetLimit`, which writes the
/// queried limit through it. The caller must supply a pointer that is valid
/// for writing one `usize`.
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
    // SAFETY: the pointer contract is delegated to the caller (see `# Safety`).
    unsafe { hipDeviceGetLimit(pvalue, limit) }
}
/// Sets a device resource limit by forwarding to `hipDeviceSetLimit`,
/// returning the HIP status unchanged.
pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
    // SAFETY: both arguments are plain values; no pointers are passed.
    unsafe { hipDeviceSetLimit(limit, value) }
}
/// Waits for the current HIP device by forwarding to `hipDeviceSynchronize`,
/// returning the HIP status unchanged.
pub(crate) fn synchronize() -> hipError_t {
    // SAFETY: `hipDeviceSynchronize` takes no arguments, so there are no
    // pointer preconditions to uphold.
    unsafe { hipDeviceSynchronize() }
}
/// Returns the primary context of HIP device `hip_dev`, both as a reference
/// to the backing [`Context`] and as its CUDA handle.
///
/// # Errors
///
/// Propagates the error from `driver::device` when the device lookup fails.
pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
    Ok(driver::device(hip_dev)?.primary_context())
}
/// Implements `cuCtxSetCurrent`: binds `raw_ctx` to the calling thread.
///
/// Per the CUDA driver API contract:
/// - a non-null `raw_ctx` *replaces* the top of this thread's
///   `CONTEXT_STACK`, or becomes its only entry when the stack is empty;
/// - a null `raw_ctx` pops the top entry, which is a no-op on an empty stack.
///
/// Whenever the change makes a context on a different HIP device current,
/// `hipSetDevice` is called for that device. The device switch happens
/// *before* the stack is modified. If `hipSetDevice` fails, its error is
/// returned and the stack is left untouched.
///
/// # Errors
///
/// Returns the error from resolving `raw_ctx` to a [`Context`], or the
/// `hipSetDevice` failure converted into a CUDA error.
pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
    if raw_ctx.0 == ptr::null_mut() {
        // Popping exposes the entry underneath. Switch devices only when that
        // entry lives on a different device than the one being popped.
        let revealed_device = CONTEXT_STACK.with(|stack| {
            let stack = stack.borrow();
            match stack.as_slice() {
                [.., (_, below), (_, top)] if below != top => Some(*below),
                _ => None,
            }
        });
        if let Some(device) = revealed_device {
            unsafe { hipSetDevice(device)? };
        }
        CONTEXT_STACK.with(|stack| {
            stack.borrow_mut().pop();
        });
    } else {
        let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
        let device = ctx.device;
        // The new context replaces the current top, so compare against the
        // device of the entry being replaced (none if the stack is empty).
        let current_device =
            CONTEXT_STACK.with(|stack| stack.borrow().last().map(|(_, dev)| *dev));
        if current_device != Some(device) {
            unsafe { hipSetDevice(device)? };
        }
        CONTEXT_STACK.with(move |stack| {
            let mut stack = stack.borrow_mut();
            // Replace the top entry (cuCtxSetCurrent semantics) instead of
            // pushing, so repeated calls do not grow the stack.
            stack.pop();
            stack.push((raw_ctx, device));
        });
    }
    Ok(())
}
|