diff options
Diffstat (limited to 'zluda/src/impl/driver.rs')
-rw-r--r-- | zluda/src/impl/driver.rs | 79 |
1 files changed, 79 insertions, 0 deletions
diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs new file mode 100644 index 0000000..7ff2f54 --- /dev/null +++ b/zluda/src/impl/driver.rs @@ -0,0 +1,79 @@ +use cuda_types::*;
+use hip_runtime_sys::*;
+use std::{
+ ffi::{CStr, CString},
+ mem, slice,
+ sync::OnceLock,
+};
+
+use crate::r#impl::context;
+
+use super::LiveCheck;
+
+pub(crate) struct GlobalState {
+ pub devices: Vec<Device>,
+}
+
+pub(crate) struct Device {
+ pub(crate) _comgr_isa: CString,
+ primary_context: LiveCheck<context::Context>,
+}
+
+impl Device {
+ pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) {
+ unsafe {
+ (
+ self.primary_context.data.assume_init_ref(),
+ self.primary_context.as_handle(),
+ )
+ }
+ }
+}
+
+pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> {
+ global_state()?
+ .devices
+ .get(dev as usize)
+ .ok_or(CUerror::INVALID_DEVICE)
+}
+
+pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> {
+ static GLOBAL_STATE: OnceLock<Result<GlobalState, CUerror>> = OnceLock::new();
+ fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] {
+ unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) }
+ }
+ GLOBAL_STATE
+ .get_or_init(|| {
+ let mut device_count = 0;
+ unsafe { hipGetDeviceCount(&mut device_count) }?;
+ Ok(GlobalState {
+ devices: (0..device_count)
+ .map(|i| {
+ let mut props = unsafe { mem::zeroed() };
+ unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?;
+ Ok::<_, CUerror>(Device {
+ _comgr_isa: CStr::from_bytes_until_nul(cast_slice(
+ &props.gcnArchName[..],
+ ))
+ .map_err(|_| CUerror::UNKNOWN)?
+ .to_owned(),
+ primary_context: LiveCheck::new(context::new(i)),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?,
+ })
+ })
+ .as_ref()
+ .map_err(|e| *e)
+}
+
+pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult {
+ unsafe { hipInit(flags) }?;
+ global_state()?;
+ Ok(())
+}
+
+pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult {
+ *version = cuda_types::CUDA_VERSION as i32;
+ Ok(())
+}
|