aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda/src/impl/context.rs
blob: e812e8558e926d08144b30f3dcbff98279257ebc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use super::{driver, FromCuda, ZludaObject};
use cuda_types::cuda::*;
use hip_runtime_sys::*;
use rustc_hash::FxHashSet;
use std::{cell::RefCell, ptr, sync::Mutex};

thread_local! {
    /// Per-thread stack of `(context handle, HIP device)` pairs, initially empty.
    ///
    /// `set_current` in this module mutates this stack and compares the device
    /// stored in its last entry against the incoming one to decide whether
    /// `hipSetDevice` has to be called.
    pub(crate) static CONTEXT_STACK: RefCell<Vec<(CUcontext, hipDevice_t)>> = RefCell::new(Vec::new());
}

/// State backing a `CUcontext` handle (see the `ZludaObject` impl for `Context`).
pub(crate) struct Context {
    /// HIP device this context is associated with; fixed when the context is
    /// built by `new`.
    pub(crate) device: hipDevice_t,
    /// Mutable bookkeeping for the context, guarded by a mutex.
    pub(crate) mutable: Mutex<OwnedByContext>,
}

/// Bookkeeping stored behind `Context::mutable`.
///
/// NOTE(review): the underscore-prefixed sets are never read or written in
/// this file beyond being initialized empty; presumably they are meant to
/// track resources owned by the context — confirm against the rest of the crate.
pub(crate) struct OwnedByContext {
    pub(crate) ref_count: usize, // only used by primary context
    /// Set of HIP device pointers; starts empty.
    pub(crate) _memory: FxHashSet<hipDeviceptr_t>,
    /// Set of HIP stream handles; starts empty.
    pub(crate) _streams: FxHashSet<hipStream_t>,
    /// Set of CUDA module handles; starts empty.
    pub(crate) _modules: FxHashSet<CUmodule>,
}

/// Makes `Context` usable as a ZLUDA object addressed through `CUcontext`.
impl ZludaObject for Context {
    // NOTE(review): presumably a per-type tag used to validate handles before
    // they are dereferenced — confirm against the `ZludaObject` definition.
    const COOKIE: usize = 0x5f867c6d9cb73315;

    /// CUDA handle type that refers to a `Context`.
    type CudaHandle = CUcontext;

    /// Performs no cleanup and always reports success.
    fn drop_checked(&mut self) -> CUresult {
        Ok(())
    }
}

/// Builds a fresh [`Context`] bound to `device`.
///
/// The mutex-protected bookkeeping starts out with a reference count of zero
/// and empty memory, stream and module sets.
pub(crate) fn new(device: hipDevice_t) -> Context {
    let owned = OwnedByContext {
        ref_count: 0,
        _memory: Default::default(),
        _streams: Default::default(),
        _modules: Default::default(),
    };
    Context {
        device,
        mutable: Mutex::new(owned),
    }
}

/// Forwards to `hipDeviceGetLimit`, passing `pvalue` and `limit` through
/// unchanged and returning the HIP status as-is.
///
/// # Safety
///
/// `pvalue` is handed to HIP without any validation here. Presumably it must
/// be non-null and valid for writing a `usize` — confirm against the HIP
/// documentation for `hipDeviceGetLimit`.
pub(crate) unsafe fn get_limit(pvalue: *mut usize, limit: hipLimit_t) -> hipError_t {
    unsafe { hipDeviceGetLimit(pvalue, limit) }
}

/// Forwards to `hipDeviceSetLimit` with the given `limit` and `value`,
/// returning the HIP status as-is.
pub(crate) fn set_limit(limit: hipLimit_t, value: usize) -> hipError_t {
    // Only plain values cross the FFI boundary here; no pointers are passed.
    unsafe { hipDeviceSetLimit(limit, value) }
}

/// Forwards to `hipDeviceSynchronize` and returns the HIP status as-is.
pub(crate) fn synchronize() -> hipError_t {
    // The call takes no arguments, so nothing caller-provided crosses the FFI boundary.
    unsafe { hipDeviceSynchronize() }
}

/// Returns the primary context of `hip_dev`, together with its CUDA handle.
///
/// The device is resolved through `driver::device`; the result is whatever
/// that device's `primary_context()` yields.
///
/// # Errors
///
/// Propagates the error produced by `driver::device` when the lookup fails.
pub(crate) fn get_primary(hip_dev: hipDevice_t) -> Result<(&'static Context, CUcontext), CUerror> {
    Ok(driver::device(hip_dev)?.primary_context())
}

/// Makes `raw_ctx` the current context of the calling thread.
///
/// * **Null handle** — the top entry of this thread's `CONTEXT_STACK` is
///   popped (a no-op when the stack is empty).
/// * **Non-null handle** — the handle is resolved to its [`Context`] and
///   becomes the top of the stack: it *replaces* the existing top entry, or is
///   pushed when the stack is empty. This follows the documented
///   `cuCtxSetCurrent` contract and keeps repeated calls from growing the
///   per-thread stack without bound.
///
/// `hipSetDevice` is invoked only when the device of the top entry actually
/// changes, or when the first entry is installed on an empty stack.
///
/// # Errors
///
/// Propagates the error from resolving a non-null `raw_ctx` (in which case the
/// stack is left untouched) and the error from `hipSetDevice`.
pub(crate) fn set_current(raw_ctx: CUcontext) -> CUresult {
    let new_device = if raw_ctx.0.is_null() {
        CONTEXT_STACK.with(|stack| {
            let mut stack = stack.borrow_mut();
            // Empty stack: nothing to unbind, no device switch.
            let (_, old_device) = stack.pop()?;
            // Popped the last entry: no context left to switch to.
            let &(_, new_device) = stack.last()?;
            (old_device != new_device).then_some(new_device)
        })
    } else {
        let ctx: &Context = FromCuda::from_cuda(&raw_ctx)?;
        let device = ctx.device;
        CONTEXT_STACK.with(move |stack| {
            let mut stack = stack.borrow_mut();
            // Pop-then-push replaces the top entry, or installs the first one
            // when the stack is empty.
            let previous_device = stack.pop().map(|(_, dev)| dev);
            stack.push((raw_ctx, device));
            match previous_device {
                Some(previous) if previous == device => None,
                _ => Some(device),
            }
        })
    };
    if let Some(dev) = new_device {
        // NOTE: the stack has already been updated at this point; a failing
        // `hipSetDevice` is reported to the caller but the stack is not rolled back.
        unsafe { hipSetDevice(dev)? };
    }
    Ok(())
}