aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src')
-rw-r--r--ptx/src/pass/emit_llvm.rs80
-rw-r--r--ptx/src/pass/insert_explicit_load_store.rs42
-rw-r--r--ptx/src/pass/mod.rs4
-rw-r--r--ptx/src/pass/replace_instructions_with_function_calls.rs3
-rw-r--r--ptx/src/pass/replace_known_functions.rs38
-rw-r--r--ptx/src/test/spirv_run/mod.rs6
6 files changed, 158 insertions, 15 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index fa011a3..2d1269d 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -96,10 +96,6 @@ impl Module {
let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
-
- fn write_to_stderr(&self) {
- unsafe { LLVMDumpModule(self.get()) };
- }
}
impl Drop for Module {
@@ -183,7 +179,6 @@ pub(super) fn run<'input>(
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
- module.write_to_stderr();
if let Err(err) = module.verify() {
panic!("{:?}", err);
}
@@ -246,6 +241,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
.map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true");
+ self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
+ self.emit_fn_attribute(fn_, "no-trapping-math", "true");
}
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
@@ -404,6 +402,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
ptx_parser::ScalarType::BF16x2 => todo!(),
})
}
+
+ fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) {
+ let attribute = unsafe {
+ LLVMCreateStringAttribute(
+ self.context,
+ key.as_ptr() as _,
+ key.len() as u32,
+ value.as_ptr() as _,
+ value.len() as u32,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) };
+ }
}
fn get_input_argument_type(
@@ -529,7 +540,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
- ast::Instruction::Abs { .. } => todo!(),
+ ast::Instruction::Abs { data, arguments } => self.emit_abs(data, arguments),
ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
@@ -539,7 +550,6 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
- ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
@@ -559,6 +569,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Trap {} => todo!(),
// replaced by a function call
ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bar { .. }
| ast::Instruction::Bfi { .. }
| ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}
@@ -1570,8 +1581,12 @@ impl<'a> MethodEmitContext<'a> {
Some(LLVMBuildFPToUI),
)
}
- ptx_parser::CvtMode::FPFromSigned(_) => todo!(),
- ptx_parser::CvtMode::FPFromUnsigned(_) => todo!(),
+ ptx_parser::CvtMode::FPFromSigned(_) => {
+ return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildSIToFP)
+ }
+ ptx_parser::CvtMode::FPFromUnsigned(_) => {
+ return self.emit_cvt_int_to_float(data.to, arguments, LLVMBuildUIToFP)
+ }
};
let src = self.resolver.value(arguments.src)?;
self.resolver.with_result(arguments.dst, |dst| unsafe {
@@ -1726,6 +1741,25 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_cvt_int_to_float(
+ &mut self,
+ to: ptx_parser::ScalarType,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ llvm_func: unsafe extern "C" fn(
+ arg1: LLVMBuilderRef,
+ Val: LLVMValueRef,
+ DestTy: LLVMTypeRef,
+ Name: *const i8,
+ ) -> LLVMValueRef,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, to);
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_func(self.builder, src, type_, dst)
+ });
+ Ok(())
+ }
+
fn emit_rsqrt(
&mut self,
data: ptx_parser::TypeFtz,
@@ -1994,7 +2028,7 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
- ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
@@ -2021,7 +2055,7 @@ impl<'a> MethodEmitContext<'a> {
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
- ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
@@ -2149,6 +2183,30 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_abs(
+ &mut self,
+ data: ast::TypeFtz,
+ arguments: ptx_parser::AbsArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_type = get_scalar_type(self.context, data.type_);
+ let src = self.resolver.value(arguments.src)?;
+ let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float {
+ ("llvm.fabs", vec![(src, llvm_type)])
+ } else {
+ let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
+ let zero = unsafe { LLVMConstInt(pred, 0, 0) };
+ ("llvm.abs", vec![(src, llvm_type), (zero, pred)])
+ };
+ let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_));
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_.into(),
+ intrinsic_arguments,
+ )?;
+ Ok(())
+ }
+
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
index 60c4a14..702f733 100644
--- a/ptx/src/pass/insert_explicit_load_store.rs
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -122,6 +122,13 @@ fn run_statement<'a, 'input>(
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
+ Statement::PtrAccess(ptr_access) => {
+ let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
+ let statement = statement.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
@@ -259,6 +266,41 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Ok(ast::Instruction::Ld { data, arguments })
}
+ fn visit_ptr_access(
+ &mut self,
+ ptr_access: PtrAccess<SpirvWord>,
+ ) -> Result<PtrAccess<SpirvWord>, TranslateError> {
+ let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
+ Some(RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name,
+ }) => (*old_space, *new_space, *name),
+ Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
+ };
+ if ptr_access.state_space != old_space {
+ return Err(error_mismatched_type());
+ }
+ // Propagate space changes in dst
+ let new_dst = self
+ .resolver
+ .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
+ self.variables.insert(
+ ptr_access.dst,
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name: new_dst,
+ },
+ );
+ Ok(PtrAccess {
+ ptr_src: name,
+ dst: new_dst,
+ state_space: new_space,
+ ..ptr_access
+ })
+ }
+
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index ef131b4..c32cc39 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -22,6 +22,7 @@ mod normalize_identifiers2;
mod normalize_predicates2;
mod replace_instructions_with_function_calls;
mod resolve_function_pointers;
+mod replace_known_functions;
static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
@@ -42,9 +43,10 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = replace_known_functions::run(&flat_resolver, directives);
let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
- let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives: Vec<Directive2<'_, ptx_parser::Instruction<ptx_parser::ParsedOperand<SpirvWord>>, ptx_parser::ParsedOperand<SpirvWord>>> = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs
index 70d77d3..668cc21 100644
--- a/ptx/src/pass/replace_instructions_with_function_calls.rs
+++ b/ptx/src/pass/replace_instructions_with_function_calls.rs
@@ -104,6 +104,9 @@ fn run_instruction<'input>(
let name = ["bfi_", scalar_to_ptx_name(data)].concat();
to_call(resolver, fn_declarations, name.into(), i)?
}
+ i @ ptx_parser::Instruction::Bar { .. } => {
+ to_call(resolver, fn_declarations, "bar_sync".into(), i)?
+ }
i => i,
})
}
diff --git a/ptx/src/pass/replace_known_functions.rs b/ptx/src/pass/replace_known_functions.rs
new file mode 100644
index 0000000..56bb7e6
--- /dev/null
+++ b/ptx/src/pass/replace_known_functions.rs
@@ -0,0 +1,38 @@
+use super::{GlobalStringIdentResolver2, NormalizedDirective2, SpirvWord};
+
+pub(crate) fn run<'input>(
+ resolver: &GlobalStringIdentResolver2<'input>,
+ mut directives: Vec<NormalizedDirective2<'input>>,
+) -> Vec<NormalizedDirective2<'input>> {
+ for directive in directives.iter_mut() {
+ match directive {
+ NormalizedDirective2::Method(func) => {
+ func.import_as =
+ replace_with_ptx_impl(resolver, &func.func_decl.name, func.import_as.take());
+ }
+ _ => {}
+ }
+ }
+ directives
+}
+
+fn replace_with_ptx_impl<'input>(
+ resolver: &GlobalStringIdentResolver2<'input>,
+ fn_name: &ptx_parser::MethodName<'input, SpirvWord>,
+ name: Option<String>,
+) -> Option<String> {
+ let known_names = ["__assertfail"];
+ match name {
+ Some(name) if known_names.contains(&&*name) => Some(format!("__zluda_ptx_impl_{}", name)),
+ Some(name) => Some(name),
+ None => match fn_name {
+ ptx_parser::MethodName::Func(name) => match resolver.ident_map.get(name) {
+ Some(super::IdentEntry {
+ name: Some(name), ..
+ }) => Some(format!("__zluda_ptx_impl_{}", name)),
+ _ => None,
+ },
+ ptx_parser::MethodName::Kernel(..) => None,
+ },
+ }
+}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f4b7921..e4171cd 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -298,7 +298,7 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
let mut result = vec![0u8.into(); output.len()];
{
let dev = 0;
- let mut stream = ptr::null_mut();
+ let mut stream = unsafe { mem::zeroed() };
unsafe { hipStreamCreate(&mut stream) }.unwrap();
let mut dev_props = unsafe { mem::zeroed() };
unsafe { hipGetDevicePropertiesR0600(&mut dev_props, dev) }.unwrap();
@@ -308,9 +308,9 @@ fn run_hip<Input: From<u8> + Copy + Debug, Output: From<u8> + Copy + Debug + Def
module.linked_bitcode(),
)
.unwrap();
- let mut module = ptr::null_mut();
+ let mut module = unsafe { mem::zeroed() };
unsafe { hipModuleLoadData(&mut module, elf_module.as_ptr() as _) }.unwrap();
- let mut kernel = ptr::null_mut();
+ let mut kernel = unsafe { mem::zeroed() };
unsafe { hipModuleGetFunction(&mut kernel, module, name.as_ptr()) }.unwrap();
let mut inp_b = ptr::null_mut();
unsafe { hipMalloc(&mut inp_b, input.len() * mem::size_of::<Input>()) }.unwrap();