diff options
author | Andrzej Janik <[email protected]> | 2024-12-05 22:06:06 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-12-05 22:06:06 +0100 |
commit | 225b9f20743274cd9474c74196112447e562f401 (patch) | |
tree | 48c0110e92d59080030cbb25eff3371028d4fe96 | |
parent | 69f76bc5770ffc33af5cdda9c879e46f20f09364 (diff) | |
download | ZLUDA-225b9f20743274cd9474c74196112447e562f401.tar.gz ZLUDA-225b9f20743274cd9474c74196112447e562f401.zip |
Add fn perf attributes, fix typos
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 71eb03c..2d1269d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -241,6 +241,9 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ self.emit_fn_attribute(fn_, "amdgpu-unsafe-fp-atomics", "true");
+ self.emit_fn_attribute(fn_, "uniform-work-group-size", "true");
+ self.emit_fn_attribute(fn_, "no-trapping-math", "true");
}
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
@@ -399,6 +402,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { ptx_parser::ScalarType::BF16x2 => todo!(),
})
}
+
+ fn emit_fn_attribute(&self, llvm_object: LLVMValueRef, key: &str, value: &str) { // Creates the string attribute `key` = `value` and adds it to `llvm_object` at `LLVMAttributeFunctionIndex`.
+ let attribute = unsafe { // The attribute is created through `self.context`.
+ LLVMCreateStringAttribute(
+ self.context,
+ key.as_ptr() as _, // Key is passed as a raw pointer plus an explicit byte length (next argument).
+ key.len() as u32, // NOTE(review): `as u32` would silently truncate a length above u32::MAX; the visible callers pass short string literals.
+ value.as_ptr() as _, // Value follows the same pointer + length convention as the key.
+ value.len() as u32,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(llvm_object, LLVMAttributeFunctionIndex, attribute) }; // presumably `llvm_object` must refer to a function (the visible callers pass the result of `LLVMAddFunction`) — confirm for any other caller
+ }
}
fn get_input_argument_type(
@@ -2012,7 +2028,7 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
- ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
@@ -2039,7 +2055,7 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
return Err(error_todo())
}
- ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};
let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
|