From 225b9f20743274cd9474c74196112447e562f401 Mon Sep 17 00:00:00 2001
From: Andrzej Janik
Date: Thu, 5 Dec 2024 22:06:06 +0100
Subject: Add fn perf attributes, fix typos

---
 ptx/src/pass/emit_llvm.rs | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 71eb03c..2d1269d 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -241,6 +241,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
                     .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
             )?;
             fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+            self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true");
+            self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
+            self.emit_fn_attribute(fn_, "no-trapping-math", "true");
         }
         if let ast::MethodName::Func(name) = func_decl.name {
             self.resolver.register(name, fn_);
@@ -399,6 +402,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
             ptx_parser::ScalarType::BF16x2 => todo!(),
         })
     }
+
+    fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) {
+        let attribute = unsafe {
+            LLVMCreateStringAttribute(
+                self.context,
+                key.as_ptr() as _,
+                key.len() as u32,
+                value.as_ptr() as _,
+                value.len() as u32,
+            )
+        };
+        unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) };
+    }
 }
 
 fn get_input_argument_type(
@@ -2012,7 +2028,7 @@ impl<'a> MethodEmitContext<'a> {
             ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
                 return Err(error_todo())
             }
-            ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+            ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
         };
         let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
         let llvm_type = get_scalar_type(self.context, data.type_());
@@ -2039,7 +2055,7 @@ impl<'a> MethodEmitContext<'a> {
             ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
                 return Err(error_todo())
             }
-            ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+            ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
         };
         let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
         let llvm_type = get_scalar_type(self.context, data.type_());
-- 
cgit v1.2.3