diff options
author | Andrzej Janik <[email protected]> | 2020-02-17 21:14:23 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-02-17 21:14:23 +0100 |
commit | 925af66b192dae89d25c7f03fead00812280169e (patch) | |
tree | a8c18cd10f0135578a66a89fa6ba478a13d82b9d | |
parent | 796e030c4eaf9efd1bde91b21be366aed523a65e (diff) | |
download | ZLUDA-925af66b192dae89d25c7f03fead00812280169e.tar.gz ZLUDA-925af66b192dae89d25c7f03fead00812280169e.zip |
Return max memory
-rw-r--r-- | notcuda/src/export_table.rs | 4 | ||||
-rw-r--r-- | notcuda/src/lib.rs | 85 | ||||
-rw-r--r-- | notcuda/src/ze.rs | 38 |
3 files changed, 110 insertions, 17 deletions
diff --git a/notcuda/src/export_table.rs b/notcuda/src/export_table.rs index 0fef013..d9abeef 100644 --- a/notcuda/src/export_table.rs +++ b/notcuda/src/export_table.rs @@ -11,10 +11,12 @@ pub unsafe extern "C" fn cuGetExportTable( ) -> cu::Result {
if *id == CU_ETID_ToolsRuntimeCallbackHooks {
*table = TABLE0.as_ptr() as *const _;
+ return cu::Result::SUCCESS;
} else if *id == CU_ETID_CudartInterface {
*table = TABLE1.as_ptr() as *const _;
+ return cu::Result::SUCCESS;
}
- return cu::Result::SUCCESS;
+ cu::Result::ERROR_NOT_SUPPORTED
}
const CU_ETID_ToolsRuntimeCallbackHooks: cu::Uuid = cu::Uuid {
diff --git a/notcuda/src/lib.rs b/notcuda/src/lib.rs index 27b34a6..bb17f8b 100644 --- a/notcuda/src/lib.rs +++ b/notcuda/src/lib.rs @@ -5,17 +5,31 @@ extern crate lazy_static; use std::sync::Mutex; use std::ptr; use std::cmp; -use std::os::raw::{c_char, c_int}; +use std::os::raw::{c_char, c_int, c_uint}; mod cu; mod export_table; +mod ze; + +use ze::Versioned; + +macro_rules! l0_check_err { + ($exp:expr) => { + { + let result = unsafe{ $exp }; + if result != l0::ze_result_t::ZE_RESULT_SUCCESS { + return Err(result); + } + } + }; +} macro_rules! l0_check { ($exp:expr) => { { let result = unsafe{ $exp }; if result != l0::ze_result_t::ZE_RESULT_SUCCESS { - return Err(result) + return result; } } }; @@ -36,11 +50,11 @@ impl Driver { fn new() -> Result<Driver, l0::ze_result_t> { let mut driver_count = 1; let mut handle = ptr::null_mut(); - l0_check!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; + l0_check_err!{ l0::zeDriverGet(&mut driver_count, &mut handle) }; let mut count = 0; - l0_check! { l0::zeDeviceGet(handle, &mut count, ptr::null_mut()) } + l0_check_err! { l0::zeDeviceGet(handle, &mut count, ptr::null_mut()) } let mut devices = vec![ptr::null_mut(); count as usize]; - l0_check! { l0::zeDeviceGet(handle, &mut count, devices.as_mut_ptr()) } + l0_check_err! { l0::zeDeviceGet(handle, &mut count, devices.as_mut_ptr()) } if (count as usize) < devices.len() { devices.truncate(count as usize); } @@ -67,7 +81,7 @@ impl Driver { } fn device_get(&self, device: *mut cu::Device, ordinal: c_int) -> l0::ze_result_t { - if ordinal < 0 || (ordinal as usize) >= self.devices.len() { + if (ordinal as usize) >= self.devices.len() { return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; } unsafe { *device = cu::Device(ordinal) }; @@ -75,35 +89,55 @@ impl Driver { } fn device_get_name(&self, name: *mut c_char, len: c_int, cu::Device(dev): cu::Device) -> l0::ze_result_t { - if len <= 0 || dev < 0 || (dev as usize) >= self.devices.len() { + if (dev as usize) >= self.devices.len() { return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; } - let mut props = Box::new(unsafe { std::mem::zeroed::<l0::ze_device_properties_t>() }); - props.version = l0::ze_device_properties_version_t::ZE_DEVICE_PROPERTIES_VERSION_CURRENT; - let result = unsafe { l0::zeDeviceGetProperties(self.devices[dev as usize], props.as_mut()) }; - if result != l0::ze_result_t::ZE_RESULT_SUCCESS { - return result; - } + let mut props = Box::new(l0::ze_device_properties_t::new()); + l0_check! { l0::zeDeviceGetProperties(self.devices[dev as usize], props.as_mut()) }; let null_pos = props.name.iter().position(|&c| c == 0).unwrap_or(0); let dst_null_pos = cmp::min((len - 1) as usize, null_pos); unsafe { *(name.add(dst_null_pos)) = 0 }; unsafe { std::ptr::copy_nonoverlapping(props.name.as_ptr(), name, dst_null_pos) }; l0::ze_result_t::ZE_RESULT_SUCCESS } + + fn device_total_mem(&self, bytes: *mut usize, cu::Device(dev): cu::Device) -> l0::ze_result_t { + if (dev as usize) >= self.devices.len() { + return l0::ze_result_t::ZE_RESULT_ERROR_INVALID_ARGUMENT; + } + let dev = dev as usize; + let mut count = 0; + l0_check! { l0::zeDeviceGetMemoryProperties(self.devices[dev], &mut count, ptr::null_mut()) }; + if count == 0 { + return l0::ze_result_t::ZE_RESULT_ERROR_UNKNOWN; + } + let mut props = vec![l0::ze_device_memory_properties_t::new(); count as usize]; + l0_check! { l0::zeDeviceGetMemoryProperties(self.devices[dev], &mut count, props.as_mut_ptr()) }; + let iter_count = cmp::min(count as usize, props.len()); + if iter_count == 0 { + return l0::ze_result_t::ZE_RESULT_ERROR_UNKNOWN; + } + let max_mem = props.iter().take(iter_count).map(|p| p.totalSize).max().unwrap(); + unsafe { *bytes = max_mem as usize }; + l0::ze_result_t::ZE_RESULT_SUCCESS + } } #[no_mangle] -pub extern "C" fn cuDriverGetVersion(version: &mut c_int) -> cu::Result { +pub unsafe extern "C" fn cuDriverGetVersion(version: *mut c_int) -> cu::Result { + if version == ptr::null_mut() { + return cu::Result::ERROR_INVALID_VALUE; + } *version = i32::max_value(); return cu::Result::SUCCESS; } #[no_mangle] -pub unsafe extern "C" fn cuInit(_: *const c_int) -> cu::Result { +pub unsafe extern "C" fn cuInit(_: c_uint) -> cu::Result { let l0_init = l0::zeInit(l0::ze_init_flag_t::ZE_INIT_FLAG_GPU_ONLY); if l0_init != l0::ze_result_t::ZE_RESULT_SUCCESS { return cu::Result::from_l0(l0_init); - }; + } let mut lock = GLOBAL_STATE.try_lock(); if let Ok(ref mut mutex) = lock { if let None = **mutex { @@ -120,15 +154,34 @@ pub unsafe extern "C" fn cuInit(_: *const c_int) -> cu::Result { #[no_mangle] pub extern "C" fn cuDeviceGetCount(count: *mut c_int) -> cu::Result { + if count == ptr::null_mut() { + return cu::Result::ERROR_INVALID_VALUE; + } Driver::call(|driver| driver.device_get_count(count)) } #[no_mangle] pub extern "C" fn cuDeviceGet(device: *mut cu::Device, ordinal: c_int) -> cu::Result { + if ordinal < 0 || device == ptr::null_mut() { + return cu::Result::ERROR_INVALID_VALUE; + } Driver::call(|driver| driver.device_get(device, ordinal)) } #[no_mangle] pub extern "C" fn cuDeviceGetName(name: *mut c_char, len: c_int, dev: cu::Device) -> cu::Result { + let cu::Device(dev_idx) = dev; + if len <= 0 || dev_idx < 0 || name == ptr::null_mut() { + return cu::Result::ERROR_INVALID_VALUE; + } Driver::call(|driver| driver.device_get_name(name, len, dev)) +} + +#[no_mangle] +pub extern "C" fn cuDeviceTotalMem_v2(bytes: *mut usize, dev: cu::Device) -> cu::Result { + let cu::Device(dev_idx) = dev; + if dev_idx < 0 || bytes == ptr::null_mut() { + return cu::Result::ERROR_INVALID_VALUE; + } + Driver::call(|driver| driver.device_total_mem(bytes, dev)) }
\ No newline at end of file diff --git a/notcuda/src/ze.rs b/notcuda/src/ze.rs new file mode 100644 index 0000000..6d798b1 --- /dev/null +++ b/notcuda/src/ze.rs @@ -0,0 +1,38 @@ +use level_zero_sys::*; + +pub trait Versioned : Sized { + type Version; + + fn new() -> Self { + let mut result = unsafe { std::mem::zeroed::<Self>() }; + let ver = result.version(); + *ver = Self::current(); + return result; + } + + fn current() -> Self::Version; + + fn version(&mut self) -> &mut Self::Version; +} + +impl Versioned for ze_device_memory_properties_t { + type Version = ze_device_memory_properties_version_t; + fn current() -> Self::Version { + ze_device_memory_properties_version_t::ZE_DEVICE_MEMORY_PROPERTIES_VERSION_CURRENT + } + fn version(&mut self) -> &mut Self::Version { + &mut self.version + } +} + +impl Versioned for ze_device_properties_t { + type Version = ze_device_properties_version_t; + fn current() -> Self::Version { + ze_device_properties_version_t::ZE_DEVICE_PROPERTIES_VERSION_CURRENT + } + fn version(&mut self) -> &mut Self::Version { + &mut self.version + } +} + + |