diff options
-rw-r--r-- | ptx/src/translate.rs | 1 | ||||
-rw-r--r-- | zluda/src/impl/module.rs | 40 | ||||
-rw-r--r-- | zluda_ml/src/impl.rs | 35 | ||||
-rw-r--r-- | zluda_ml/src/nvml.rs | 9 |
4 files changed, 40 insertions, 45 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ee1a1d0..66fae15 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1112,6 +1112,7 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Float16);
builder.capability(spirv::Capability::Float64);
+ builder.capability(spirv::Capability::DenormFlushToZero);
// TODO: re-enable when Intel float control extension works
//builder.capability(spirv::Capability::FunctionFloatControlINTEL);
}
diff --git a/zluda/src/impl/module.rs b/zluda/src/impl/module.rs index 5560526..6234909 100644 --- a/zluda/src/impl/module.rs +++ b/zluda/src/impl/module.rs @@ -133,7 +133,11 @@ pub(crate) fn compile_amd<'a>( return Err(io::Error::new(io::ErrorKind::Other, "")); }; let dir = tempfile::tempdir()?; - let spirv_files = spirv_il + let llvm_spirv_path = match env::var("LLVM_SPIRV") { + Ok(path) => Cow::Owned(path), + Err(_) => Cow::Borrowed(LLVM_SPIRV), + }; + let llvm_files = spirv_il .map(|spirv| { let mut spirv_file = NamedTempFile::new_in(&dir)?; let spirv_u8 = unsafe { @@ -143,24 +147,24 @@ pub(crate) fn compile_amd<'a>( ) }; spirv_file.write_all(spirv_u8)?; - Ok::<_, io::Error>(spirv_file) + if cfg!(debug_assertions) { + persist_file(spirv_file.path())?; + } + let llvm = NamedTempFile::new_in(&dir)?; + let to_llvm_cmd = Command::new(&*llvm_spirv_path) + //.arg("--spirv-debug") + .arg("-r") + .arg("-o") + .arg(llvm.path()) + .arg(spirv_file.path()) + .status()?; + assert!(to_llvm_cmd.success()); + if cfg!(debug_assertions) { + persist_file(llvm.path())?; + } + Ok::<_, io::Error>(llvm) }) .collect::<Result<Vec<_>, _>>()?; - let llvm_spirv_path = match env::var("LLVM_SPIRV") { - Ok(path) => Cow::Owned(path), - Err(_) => Cow::Borrowed(LLVM_SPIRV), - }; - let llvm = NamedTempFile::new_in(&dir)?; - let to_llvm_cmd = Command::new(&*llvm_spirv_path) - .arg("-r") - .arg("-o") - .arg(llvm.path()) - .args(spirv_files.iter().map(|f| f.path())) - .status()?; - assert!(to_llvm_cmd.success()); - if cfg!(debug_assertions) { - persist_file(llvm.path())?; - } let linked_binary = NamedTempFile::new_in(&dir)?; let mut llvm_link = PathBuf::from(AMDGPU); llvm_link.push("llvm"); @@ -171,7 +175,7 @@ pub(crate) fn compile_amd<'a>( .arg("--only-needed") .arg("-o") .arg(linked_binary.path()) - .arg(llvm.path()) + .args(llvm_files.iter().map(|f| f.path())) .args(get_bitcode_paths(device_name)); if cfg!(debug_assertions) { linker_cmd.arg("-v"); diff --git a/zluda_ml/src/impl.rs b/zluda_ml/src/impl.rs index 584d4aa..48141bd 100644 --- a/zluda_ml/src/impl.rs +++ b/zluda_ml/src/impl.rs @@ -7,6 +7,8 @@ use std::{ use crate::nvml::nvmlReturn_t;
+const VERSION: &'static [u8] = b"418.40.04";
+
macro_rules! stringify_nmvlreturn_t {
($x:ident => [ $($variant:ident),+ ]) => {
match $x {
@@ -70,7 +72,7 @@ pub(crate) fn shutdown() -> nvmlReturn_t { static mut DEVICE: Option<ocl_core::DeviceId> = None;
-pub(crate) fn init_v2() -> Result<(), nvmlReturn_t> {
+pub(crate) fn init() -> Result<(), nvmlReturn_t> {
let platforms = ocl_core::get_platform_ids()?;
let device = platforms.iter().find_map(|plat| {
let devices = ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
@@ -97,7 +99,7 @@ pub(crate) fn init_v2() -> Result<(), nvmlReturn_t> { }
pub(crate) fn init_with_flags() -> Result<(), nvmlReturn_t> {
- init_v2()
+ init()
}
impl From<ocl_core::Error> for nvmlReturn_t {
@@ -131,32 +133,15 @@ impl<T: std::io::Write> std::io::Write for CountingWriter<T> { }
}
-pub(crate) fn system_get_driver_version(
+pub(crate) unsafe fn system_get_driver_version(
version_ptr: *mut c_char,
length: c_uint,
) -> Result<(), nvmlReturn_t> {
- if version_ptr == ptr::null_mut() {
+ if version_ptr == ptr::null_mut() || length == 0 {
return Err(nvmlReturn_t::NVML_ERROR_INVALID_ARGUMENT);
}
- let device = match unsafe { DEVICE } {
- Some(d) => d,
- None => return Err(nvmlReturn_t::NVML_ERROR_UNINITIALIZED),
- };
-
- if let Ok(ocl_core::DeviceInfoResult::DriverVersion(driver_version)) =
- ocl_core::get_device_info(device, ocl_core::DeviceInfo::DriverVersion)
- {
- let output_slice =
- unsafe { slice::from_raw_parts_mut(version_ptr as *mut u8, (length - 1) as usize) };
- let mut output_write = CountingWriter {
- base: output_slice,
- len: 0,
- };
- write!(&mut output_write, "{}", driver_version)
- .map_err(|_| nvmlReturn_t::NVML_ERROR_UNKNOWN)?;
- unsafe { *(version_ptr.add(output_write.len)) = 0 };
- Ok(())
- } else {
- Err(nvmlReturn_t::NVML_ERROR_UNKNOWN)
- }
+ let strlen = usize::min(VERSION.len(), (length as usize) - 1);
+ std::ptr::copy_nonoverlapping(VERSION.as_ptr(), version_ptr as _, strlen);
+ *version_ptr.add(strlen) = 0;
+ Ok(())
}
diff --git a/zluda_ml/src/nvml.rs b/zluda_ml/src/nvml.rs index 805a51c..cab546a 100644 --- a/zluda_ml/src/nvml.rs +++ b/zluda_ml/src/nvml.rs @@ -1131,7 +1131,12 @@ pub use self::nvmlPcieLinkState_enum as nvmlPcieLinkState_t; #[no_mangle] pub extern "C" fn nvmlInit_v2() -> nvmlReturn_t { - crate::r#impl::init_v2().into() + crate::r#impl::init().into() +} + +#[no_mangle] +pub extern "C" fn nvmlInit() -> nvmlReturn_t { + crate::r#impl::init().into() } #[no_mangle] @@ -1150,7 +1155,7 @@ pub extern "C" fn nvmlErrorString(result: nvmlReturn_t) -> *const ::std::os::raw } #[no_mangle] -pub extern "C" fn nvmlSystemGetDriverVersion( +pub unsafe extern "C" fn nvmlSystemGetDriverVersion( version: *mut ::std::os::raw::c_char, length: ::std::os::raw::c_uint, ) -> nvmlReturn_t { |