diff options
author | Andrzej Janik <[email protected]> | 2024-12-05 05:43:20 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-12-05 05:43:20 +0100 |
commit | 9ca1c2da5a1fcbcaab059ee190b74d90e6575007 (patch) | |
tree | 1f5f2695e9dddfd29041fa472aa5c193d4f06b0b | |
parent | 50cfd16a0626116fa5b7380e422aa34d2a68e70b (diff) | |
download | ZLUDA-9ca1c2da5a1fcbcaab059ee190b74d90e6575007.tar.gz ZLUDA-9ca1c2da5a1fcbcaab059ee190b74d90e6575007.zip |
Resolve crashes
-rw-r--r-- | ptx/lib/zluda_ptx_impl.bc | bin | 5360 -> 7524 bytes | |||
-rw-r--r-- | ptx/lib/zluda_ptx_impl.cpp | 11 | ||||
-rw-r--r-- | ptx/src/pass/insert_explicit_load_store.rs | 42 | ||||
-rw-r--r-- | ptx/src/pass/mod.rs | 4 | ||||
-rw-r--r-- | ptx/src/pass/replace_known_functions.rs | 38 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 6 | ||||
-rw-r--r-- | zluda/src/impl/memory.rs | 4 | ||||
-rw-r--r-- | zluda/src/impl/mod.rs | 1 | ||||
-rw-r--r-- | zluda/src/lib.rs | 1 |
9 files changed, 103 insertions, 4 deletions
diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc Binary files differindex 24c20d8..6cefc81 100644 --- a/ptx/lib/zluda_ptx_impl.bc +++ b/ptx/lib/zluda_ptx_impl.bc diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp index 329a810..7af9729 100644 --- a/ptx/lib/zluda_ptx_impl.cpp +++ b/ptx/lib/zluda_ptx_impl.cpp @@ -4,6 +4,7 @@ #include <cstddef>
#include <cstdint>
+#include <hip/amd_detail/amd_device_functions.h>
#define FUNC(NAME) __device__ __attribute__((retain)) __zluda_ptx_impl_##NAME
@@ -155,4 +156,14 @@ extern "C" __builtin_amdgcn_fence(__ATOMIC_SEQ_CST, "workgroup");
__builtin_amdgcn_s_barrier();
}
+
+ void FUNC(__assertfail)(uint64_t message,
+ uint64_t file,
+ uint32_t line,
+ uint64_t function,
+ uint64_t char_size)
+ {
+ (void)char_size;
+ __assert_fail((const char *)message, (const char *)file, line, (const char *)function);
+ }
}
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index 60c4a14..702f733 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -122,6 +122,13 @@ fn run_statement<'a, 'input>( result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
+ Statement::PtrAccess(ptr_access) => {
+ let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
+ let statement = statement.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
@@ -259,6 +266,41 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { Ok(ast::Instruction::Ld { data, arguments })
}
+ fn visit_ptr_access(
+ &mut self,
+ ptr_access: PtrAccess<SpirvWord>,
+ ) -> Result<PtrAccess<SpirvWord>, TranslateError> {
+ let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
+ Some(RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name,
+ }) => (*old_space, *new_space, *name),
+ Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
+ };
+ if ptr_access.state_space != old_space {
+ return Err(error_mismatched_type());
+ }
+ // Propagate space changes in dst
+ let new_dst = self
+ .resolver
+ .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
+ self.variables.insert(
+ ptr_access.dst,
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name: new_dst,
+ },
+ );
+ Ok(PtrAccess {
+ ptr_src: name,
+ dst: new_dst,
+ state_space: new_space,
+ ..ptr_access
+ })
+ }
+
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index ef131b4..c32cc39 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -22,6 +22,7 @@ mod normalize_identifiers2; mod normalize_predicates2;
mod replace_instructions_with_function_calls;
mod resolve_function_pointers;
+mod replace_known_functions;
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
@@ -42,9 +43,10 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = replace_known_functions::run(&flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
- let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives: Vec<Directive2<'_, ptx_parser::Instruction<ptx_parser::ParsedOperand<SpirvWord>>, ptx_parser::ParsedOperand<SpirvWord>>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs new file mode 100644 index 0000000..56bb7e6 --- /dev/null +++ b/ptx/src/pass/replace_known_functions.rs @@ -0,0 +1,38 @@ +use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
+
+pub(crate) fn run<'input>(
+ resolver: &GlobalStringIdentResolver2<'input>,
+ mut directives: Vec<NormalizedDirective2<'input>>,
+) -> Vec<NormalizedDirective2<'input>> {
+ for directive in directives.iter_mut() {
+ match directive {
+ NormalizedDirective2::Method(func) => {
+ func.import_as =
+ replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
+ }
+ _ => {}
+ }
+ }
+ directives
+}
+
+fn replace_with_ptx_impl<'input>(
+ resolver: &GlobalStringIdentResolver2<'input>,
+ fn_name: &ptx_parser::MethodName<'input, SpirvWord>,
+ name: Option<String>,
+) -> Option<String> {
+ let known_names = ["__assertfail"];
+ match name {
+ Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)),
+ Some(name) => Some(name),
+ None => match fn_name {
+ ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) {
+ Some(super::IdentEntry {
+ name: Some(name), ..
+ }) => Some(format!("__zluda_ptx_impl_{}", name)),
+ _ => None,
+ },
+ ptx_parser::MethodName::Kernel(..) => None,
+ },
+ }
+}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f4b7921..e4171cd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -298,7 +298,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def let mut result = vec![0u8.into(); output.len()];
{
let dev = 0;
- let mut stream = ptr::null_mut();
+ let mut stream = unsafe { mem::zeroed() };
unsafe { hipStreamCreate(&mut stream) }.unwrap();
let mut dev_props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
@@ -308,9 +308,9 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def module.linked_bitcode(),
)
.unwrap();
- let mut module = ptr::null_mut();
+ let mut module = unsafe { mem::zeroed() };
unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap();
- let mut kernel = ptr::null_mut();
+ let mut kernel = unsafe { mem::zeroed() };
unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap();
let mut inp_b = ptr::null_mut();
unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::<Input>()) }.unwrap();
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs index 33d5a4e..18e58e7 100644 --- a/zluda/src/impl/memory.rs +++ b/zluda/src/impl/memory.rs @@ -38,3 +38,7 @@ pub(crate) fn get_address_range_v2( pub(crate) fn set_d32_v2(dst: hipDeviceptr_t, ui: ::core::ffi::c_uint, n: usize) -> hipError_t { unsafe { hipMemsetD32(dst, mem::transmute(ui), n) } } + +pub(crate) fn set_d8_v2(dst: hipDeviceptr_t, value: ::core::ffi::c_uchar, n: usize) -> hipError_t { + unsafe { hipMemsetD8(dst, value, n) } +} diff --git a/zluda/src/impl/mod.rs b/zluda/src/impl/mod.rs index 766b4a5..282f8d5 100644 --- a/zluda/src/impl/mod.rs +++ b/zluda/src/impl/mod.rs @@ -107,6 +107,7 @@ from_cuda_nop!( *const ::core::ffi::c_char, *mut ::core::ffi::c_void, *mut *mut ::core::ffi::c_void, + u8, i32, u32, usize, diff --git a/zluda/src/lib.rs b/zluda/src/lib.rs index 1f6a7ff..8efbd26 100644 --- a/zluda/src/lib.rs +++ b/zluda/src/lib.rs @@ -73,6 +73,7 @@ cuda_base::cuda_function_declarations!( cuPointerGetAttribute, cuMemGetAddressRange_v2, cuMemsetD32_v2, + cuMemsetD8_v2 ], implemented_in_function <= [ cuLaunchKernel, |