use cuda_types::*; use hip_runtime_sys::*; use std::{ ffi::{CStr, CString}, mem, slice, sync::OnceLock, }; use crate::r#impl::context; use super::LiveCheck; pub(crate) struct GlobalState { pub devices: Vec, } pub(crate) struct Device { pub(crate) _comgr_isa: CString, primary_context: LiveCheck, } impl Device { pub(crate) fn primary_context<'a>(&'a self) -> (&'a context::Context, CUcontext) { unsafe { ( self.primary_context.data.assume_init_ref(), self.primary_context.as_handle(), ) } } } pub(crate) fn device(dev: i32) -> Result<&'static Device, CUerror> { global_state()? .devices .get(dev as usize) .ok_or(CUerror::INVALID_DEVICE) } pub(crate) fn global_state() -> Result<&'static GlobalState, CUerror> { static GLOBAL_STATE: OnceLock> = OnceLock::new(); fn cast_slice<'a>(bytes: &'a [i8]) -> &'a [u8] { unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len()) } } GLOBAL_STATE .get_or_init(|| { let mut device_count = 0; unsafe { hipGetDeviceCount(&mut device_count) }?; Ok(GlobalState { devices: (0..device_count) .map(|i| { let mut props = unsafe { mem::zeroed() }; unsafe { hipGetDevicePropertiesR0600(&mut props, i) }?; Ok::<_, CUerror>(Device { _comgr_isa: CStr::from_bytes_until_nul(cast_slice( &props.gcnArchName[..], )) .map_err(|_| CUerror::UNKNOWN)? .to_owned(), primary_context: LiveCheck::new(context::new(i)), }) }) .collect::, _>>()?, }) }) .as_ref() .map_err(|e| *e) } pub(crate) fn init(flags: ::core::ffi::c_uint) -> CUresult { unsafe { hipInit(flags) }?; global_state()?; Ok(()) } pub(crate) fn get_version(version: &mut ::core::ffi::c_int) -> CUresult { *version = cuda_types::CUDA_VERSION as i32; Ok(()) }