aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--zluda_ml/src/impl.rs69
1 files changed, 47 insertions, 22 deletions
diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs
index 1068b00..4a05c50 100644
--- a/zluda_ml/src/impl.rs
+++ b/zluda_ml/src/impl.rs
@@ -1,10 +1,12 @@
-use level_zero as l0;
use std::io::Write;
+use std::slice;
use std::{
os::raw::{c_char, c_uint},
ptr,
};
+use ocl_core::ClVersions;
+
use crate::nvml::nvmlReturn_t;
macro_rules! stringify_nmvlreturn_t {
@@ -68,22 +70,41 @@ pub(crate) fn shutdown() -> nvmlReturn_t {
nvmlReturn_t::NVML_SUCCESS
}
+static mut DEVICE: Option<ocl_core::DeviceId> = None;
+
pub(crate) fn init_v2() -> Result<(), nvmlReturn_t> {
- Ok(l0::init()?)
+ let platforms = ocl_core::get_platform_ids()?;
+ let mut device = platforms.iter().find_map(|plat| {
+ let devices = ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
+ for dev in devices {
+ let vendor = ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?;
+ match vendor {
+ ocl_core::DeviceInfoResult::VendorId(0x8086)
+ | ocl_core::DeviceInfoResult::VendorId(0x1002) => {}
+ _ => continue,
+ };
+ let dev_type = ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?;
+ if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) = dev_type {
+ return Some(dev);
+ }
+ }
+ None
+ });
+ unsafe { DEVICE = device };
+ if device.is_some() {
+ Ok(())
+ } else {
+ Err(nvmlReturn_t::NVML_ERROR_UNKNOWN)
+ }
}
pub(crate) fn init_with_flags() -> Result<(), nvmlReturn_t> {
init_v2()
}
-impl From<l0::sys::ze_result_t> for nvmlReturn_t {
- fn from(l0_err: l0::sys::ze_result_t) -> Self {
- match l0_err {
- l0::sys::ze_result_t::ZE_RESULT_ERROR_UNINITIALIZED => {
- nvmlReturn_t::NVML_ERROR_UNINITIALIZED
- }
- _ => nvmlReturn_t::NVML_ERROR_UNKNOWN,
- }
+impl From<ocl_core::Error> for nvmlReturn_t {
+ fn from(err: ocl_core::Error) -> Self {
+ nvmlReturn_t::NVML_ERROR_UNKNOWN
}
}
@@ -119,21 +140,25 @@ pub(crate) fn system_get_driver_version(
if version_ptr == ptr::null_mut() {
return Err(nvmlReturn_t::NVML_ERROR_INVALID_ARGUMENT);
}
- let drivers = l0::Driver::get()?;
- let output_slice =
- unsafe { std::slice::from_raw_parts_mut(version_ptr as *mut u8, (length - 1) as usize) };
- let mut output_write = CountingWriter {
- base: output_slice,
- len: 0,
+ let device = match unsafe { DEVICE } {
+ Some(d) => d,
+ None => return Err(nvmlReturn_t::NVML_ERROR_UNINITIALIZED),
};
- for d in drivers {
- let mut props = Default::default();
- d.get_properties(&mut props)?;
- let driver_version = props.driverVersion;
+
+ if let Ok(ocl_core::DeviceInfoResult::DriverVersion(driver_version)) =
+ ocl_core::get_device_info(device, ocl_core::DeviceInfo::DriverVersion)
+ {
+ let output_slice =
+ unsafe { slice::from_raw_parts_mut(version_ptr as *mut u8, (length - 1) as usize) };
+ let mut output_write = CountingWriter {
+ base: output_slice,
+ len: 0,
+ };
write!(&mut output_write, "{}", driver_version)
.map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;
unsafe { *(version_ptr.add(output_write.len)) = 0 };
- return Ok(());
+ Ok(())
+ } else {
+ Err(nvmlReturn_t::NVML_ERROR_UNKNOWN)
}
- Err(nvmlReturn_t::NVML_ERROR_UNKNOWN)
}