aboutsummaryrefslogtreecommitdiffhomepage
path: root/zluda_ml/src/impl.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_ml/src/impl.rs')
-rw-r--r--zluda_ml/src/impl.rs35
1 files changed, 10 insertions, 25 deletions
diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs
index 584d4aa..48141bd 100644
--- a/zluda_ml/src/impl.rs
+++ b/zluda_ml/src/impl.rs
@@ -7,6 +7,8 @@ use std::{
use crate::nvml::nvmlReturn_t;
+const VERSION: &'static [u8] = b"418.40.04";
+
macro_rules! stringify_nmvlreturn_t {
($x:ident => [ $($variant:ident),+ ]) => {
match $x {
@@ -70,7 +72,7 @@ pub(crate) fn shutdown() -> nvmlReturn_t {
static mut DEVICE: Option<ocl_core::DeviceId> = None;
-pub(crate) fn init_v2() -> Result<(), nvmlReturn_t> {
+pub(crate) fn init() -> Result<(), nvmlReturn_t> {
let platforms = ocl_core::get_platform_ids()?;
let device = platforms.iter().find_map(|plat| {
let devices = ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
@@ -97,7 +99,7 @@ pub(crate) fn init_v2() -> Result<(), nvmlReturn_t> {
}
pub(crate) fn init_with_flags() -> Result<(), nvmlReturn_t> {
- init_v2()
+ init()
}
impl From<ocl_core::Error> for nvmlReturn_t {
@@ -131,32 +133,15 @@ impl<T: std::io::Write> std::io::Write for CountingWriter<T> {
}
}
-pub(crate) fn system_get_driver_version(
+pub(crate) unsafe fn system_get_driver_version(
version_ptr: *mut c_char,
length: c_uint,
) -> Result<(), nvmlReturn_t> {
- if version_ptr == ptr::null_mut() {
+ if version_ptr == ptr::null_mut() || length == 0 {
return Err(nvmlReturn_t::NVML_ERROR_INVALID_ARGUMENT);
}
- let device = match unsafe { DEVICE } {
- Some(d) => d,
- None => return Err(nvmlReturn_t::NVML_ERROR_UNINITIALIZED),
- };
-
- if let Ok(ocl_core::DeviceInfoResult::DriverVersion(driver_version)) =
- ocl_core::get_device_info(device, ocl_core::DeviceInfo::DriverVersion)
- {
- let output_slice =
- unsafe { slice::from_raw_parts_mut(version_ptr as *mut u8, (length - 1) as usize) };
- let mut output_write = CountingWriter {
- base: output_slice,
- len: 0,
- };
- write!(&mut output_write, "{}", driver_version)
- .map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;
- unsafe { *(version_ptr.add(output_write.len)) = 0 };
- Ok(())
- } else {
- Err(nvmlReturn_t::NVML_ERROR_UNKNOWN)
- }
+ let strlen = usize::min(VERSION.len(), (length as usize) - 1);
+ std::ptr::copy_nonoverlapping(VERSION.as_ptr(), version_ptr as _, strlen);
+ *version_ptr.add(strlen) = 0;
+ Ok(())
}