aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-10-11 16:27:36 +0200
committerAndrzej Janik <[email protected]>2024-10-11 16:27:36 +0200
commitc8b88f4483eaf5ee68cd9306ca57dfaa5f7d0ce0 (patch)
tree04f9690d79869b8c00644cb31d5c6860f904d0c6
parent9035c4a24d23012b15ab3d715b25d7c5a0c43740 (diff)
downloadZLUDA-c8b88f4483eaf5ee68cd9306ca57dfaa5f7d0ce0.tar.gz
ZLUDA-c8b88f4483eaf5ee68cd9306ca57dfaa5f7d0ce0.zip
Implement div
-rw-r--r--ptx/src/pass/emit_llvm.rs184
1 files changed, 182 insertions, 2 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 7c6cbb7..15177bc 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -18,6 +18,12 @@
// while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
+// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
+// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVM fail with
+// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
+// shows it fails inside amdgpu-isel. You can get a little bit further with "-mllvm -global-isel",
+// but it fails in a similar way, this time with "unable to legalize instruction".
+
use std::array::TryFromSliceError;
use std::convert::TryInto;
use std::ffi::{CStr, NulError};
@@ -534,7 +540,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
- ast::Instruction::Div { .. } => todo!(),
+ ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
ast::Instruction::Neg { .. } => todo!(),
ast::Instruction::Sin { .. } => todo!(),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
@@ -626,7 +632,7 @@ impl<'a> MethodEmitContext<'a> {
});
Ok(())
} else {
- let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
+ let _conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
&& to_type.kind() == ast::ScalarKind::Signed
{
if to_type.size_of() >= from_type.size_of() {
@@ -1086,6 +1092,147 @@ impl<'a> MethodEmitContext<'a> {
}
Ok(())
}
+
+ fn emit_div( // Emits LLVM IR for a PTX `div`: integer kinds are built here, the float kind is delegated to `emit_div_float`.
+ &mut self,
+ data: ptx_parser::DivDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let integer_div = match data { // select the LLVM-C builder function for the integer case
+ ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
+ ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
+ ptx_parser::DivDetails::Float(float_div) => { // float division returns early through its own path
+ return self.emit_div_float(float_div, arguments)
+ }
+ };
+ let src1 = self.resolver.value(arguments.src1)?; // errors from operand lookup propagate to the caller via `?`
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe { // `dst` is supplied by the resolver for `arguments.dst` and forwarded as the builder's last argument
+ integer_div(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_div_float( // Emits an LLVM `fdiv` for a PTX float `div`, sets fast-math flags on it and, for the `ApproxFull` kind, attaches `fpmath` metadata.
+ &mut self,
+ float_div: ptx_parser::DivFloatDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let _rnd = match float_div.kind { // computed but not used in this function; NOTE(review): presumably reserved for the commented-out `with_rounding` helper - confirm
+ ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
+ };
+ let approx = match float_div.kind { // fast-math flags applied to the fdiv below
+ ptx_parser::DivFloatKind::Approx => { // only the `Approx` kind gets relaxed flags
+ LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
+ }
+ ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
+ ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
+ };
+ let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { // keep the value returned by `with_result` so flags and metadata can be set on it
+ LLVMBuildFDiv(builder, src1, src2, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
+ if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { // only the `ApproxFull` kind receives `fpmath` metadata
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
+ // div.full.f32 implements a relatively fast, full-range approximation that scales
+ // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
+ // support rounding modifiers. The maximum ulp error is 2 across the full range of
+ // inputs.
+ // https://llvm.org/docs/LangRef.html#fpmath-metadata
+ let fpmath_value = // the metadata operand: an F32 constant 2.0, matching the 2 ulp bound quoted above
+ unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
+ let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
+ let mut md_node_content = [fpmath_value]; // one-element operand list for the metadata node
+ let md_node = unsafe {
+ LLVMMDNodeInContext2(
+ self.context,
+ md_node_content.as_mut_ptr(),
+ md_node_content.len(),
+ )
+ };
+ let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; // convert to a value, the form passed to `LLVMSetMetadata` below
+ let kind = unsafe { // metadata kind ID for the name "fpmath"; the name is passed as pointer plus explicit length
+ LLVMGetMDKindIDInContext(
+ self.context,
+ "fpmath".as_ptr().cast(),
+ "fpmath".len() as u32,
+ )
+ };
+ unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; // attach the node to the fdiv under the "fpmath" kind
+ }
+ Ok(())
+ }
+
+ /*
+ // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
+ // Should be available in LLVM 19
+ fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
+ let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+ let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
+ let get_rounding = c"llvm.get.rounding";
+ let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
+ let mut get_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
+ if get_rounding_fn == ptr::null_mut() {
+ get_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
+ };
+ }
+ let set_rounding = c"llvm.set.rounding";
+ let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
+ let mut set_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
+ if set_rounding_fn == ptr::null_mut() {
+ set_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
+ };
+ }
+ let mut preserved_rounding_mode = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ get_rounding_fn_type,
+ get_rounding_fn,
+ ptr::null_mut(),
+ 0,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let mut requested_rounding = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B32),
+ rounding_to_llvm(rnd) as u64,
+ 0,
+ )
+ };
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut requested_rounding,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let result = fn_(self);
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut preserved_rounding_mode,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ result
+ }
+ */
}
fn get_pointer_type<'ctx>(
@@ -1279,3 +1426,36 @@ impl ResolveIdent {
}
}
}
+
+/*
+struct ScalarTypeInLLVM(ast::ScalarType);
+
+impl std::fmt::Display for ScalarTypeInLLVM {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.0 {
+ ast::ScalarType::Pred => write!(f, "i1"),
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
+ ptx_parser::ScalarType::B128 => write!(f, "i128"),
+ ast::ScalarType::F16 => write!(f, "f16"),
+ ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
+ ast::ScalarType::F32 => write!(f, "f32"),
+ ast::ScalarType::F64 => write!(f, "f64"),
+ ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
+ ast::ScalarType::F16x2 => write!(f, "v2f16"),
+ ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
+ }
+ }
+}
+
+fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
+ match this {
+ ptx_parser::RoundingMode::Zero => 0,
+ ptx_parser::RoundingMode::NearestEven => 1,
+ ptx_parser::RoundingMode::PositiveInf => 2,
+ ptx_parser::RoundingMode::NegativeInf => 3,
+ }
+}
+*/