diff options
author | Andrzej Janik <[email protected]> | 2024-10-16 03:12:54 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-10-16 03:12:54 +0200 |
commit | 73eb31fec5f0dc73c5f43f99d21c757f3acb26cc (patch) | |
tree | 40c13e1b50781ab1f6121da7bf0b856546b74677 /ptx/src/pass/emit_llvm.rs | |
parent | 002a19354aa739a489cec82646dcc2def6893e0d (diff) | |
download | ZLUDA-73eb31fec5f0dc73c5f43f99d21c757f3acb26cc.tar.gz ZLUDA-73eb31fec5f0dc73c5f43f99d21c757f3acb26cc.zip |
Add saturated integer conversions
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 131 |
1 files changed, 125 insertions, 6 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index cc40410..d9d2745 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -1533,8 +1533,12 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
- ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
- ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
+ ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+ return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
+ }
+ ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+ return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
+ }
ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate {
rounding,
@@ -1543,7 +1547,15 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::FPRound {
integer_rounding,
flush_to_zero,
- } => todo!(),
+ } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
+ arguments,
+ Some(LLVMBuildFPToSI),
+ )
+ }
ptx_parser::CvtMode::SignedFromFP {
rounding,
flush_to_zero,
@@ -1553,7 +1565,7 @@ impl<'a> MethodEmitContext<'a> { data.to,
rounding,
arguments,
- "llvm.fptosi.sat",
+ Some(LLVMBuildFPToSI),
)
}
ptx_parser::CvtMode::UnsignedFromFP {
@@ -1565,7 +1577,7 @@ impl<'a> MethodEmitContext<'a> { data.to,
rounding,
arguments,
- "llvm.fptoui.sat",
+ Some(LLVMBuildFPToUI),
)
}
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
@@ -1578,13 +1590,105 @@ impl<'a> MethodEmitContext<'a> { Ok(())
}
+ /// Emits a saturating conversion from the unsigned integer type `from` to the
+ /// signed integer type `to`, binding the result to `arguments.dst`.
+ ///
+ /// The source value is clamped with an unsigned minimum against the largest
+ /// value representable in `to`, then resized to the destination type.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if `to` is not one of `S8`, `S16`, `S32` or `S64`, if the
+ /// source operand cannot be resolved, or if emitting the intrinsic fails.
+ fn emit_cvt_unsigned_to_signed_sat(
+     &mut self,
+     from: ptx_parser::ScalarType,
+     to: ptx_parser::ScalarType,
+     arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+     // The limit is materialized below with the type of `from`. When `to` is
+     // wider than `from` this relies on the constant keeping only its low
+     // bits: the low bits of every signed MAX are all ones, i.e. the maximum
+     // of the narrower unsigned type, which turns the clamp into a no-op.
+     let max_value = match to {
+         ptx_parser::ScalarType::S8 => i8::MAX as u64,
+         ptx_parser::ScalarType::S16 => i16::MAX as u64,
+         ptx_parser::ScalarType::S32 => i32::MAX as u64,
+         ptx_parser::ScalarType::S64 => i64::MAX as u64,
+         _ => return Err(error_unreachable()),
+     };
+     let from_llvm = get_scalar_type(self.context, from);
+     let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
+     // `llvm.umin` is an overloaded intrinsic, so its name has to carry the
+     // operand type suffix; a bare `llvm.umin` is not a correctly mangled name
+     // and would be shared between conversions with different source types.
+     let clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
+     let clamped = self.emit_intrinsic(
+         // SAFETY: the format string above terminates the name with a NUL byte;
+         // assumes the type display itself contains no interior NUL.
+         unsafe { CStr::from_bytes_with_nul_unchecked(clamp_intrinsic.as_bytes()) },
+         None,
+         &from.into(),
+         vec![
+             (self.resolver.value(arguments.src)?, from_llvm),
+             (max, from_llvm),
+         ],
+     )?;
+     // The clamped value is still an unsigned quantity: widening must
+     // zero-extend, because sign extension would turn values with the top bit
+     // set into negative numbers. For equal sizes this is a plain bitcast.
+     let resize_fn = if to.layout().size() >= from.layout().size() {
+         LLVMBuildZExtOrBitCast
+     } else {
+         LLVMBuildTrunc
+     };
+     let to_llvm = get_scalar_type(self.context, to);
+     self.resolver.with_result(arguments.dst, |dst| unsafe {
+         resize_fn(self.builder, clamped, to_llvm, dst)
+     });
+     Ok(())
+ }
+
+ /// Emits a saturating conversion from the signed integer type `from` to the
+ /// unsigned integer type `to`, binding the result to `arguments.dst`.
+ ///
+ /// Two clamps are emitted in the source type: a signed maximum against zero,
+ /// followed by an unsigned minimum against the largest value of `to`. The
+ /// clamped value is then zero-extended, bitcast or truncated to `to`.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if `to` is not one of `U8`, `U16`, `U32` or `U64`, if the
+ /// source operand cannot be resolved, or if emitting an intrinsic fails.
+ fn emit_cvt_signed_to_unsigned_sat(
+     &mut self,
+     from: ptx_parser::ScalarType,
+     to: ptx_parser::ScalarType,
+     arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+     let src_type = get_scalar_type(self.context, from);
+
+     // Step 1: negative inputs saturate to zero.
+     let lower_bound = unsafe { LLVMConstInt(src_type, 0, 0) };
+     let smax_name = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
+     let src = self.resolver.value(arguments.src)?;
+     let non_negative = self.emit_intrinsic(
+         unsafe { CStr::from_bytes_with_nul_unchecked(smax_name.as_bytes()) },
+         None,
+         &from.into(),
+         vec![(src, src_type), (lower_bound, src_type)],
+     )?;
+
+     // Step 2: the value can now be treated as unsigned, so cap it at the
+     // largest value of the destination type.
+     let upper_limit = match to {
+         ptx_parser::ScalarType::U8 => u8::MAX as u64,
+         ptx_parser::ScalarType::U16 => u16::MAX as u64,
+         ptx_parser::ScalarType::U32 => u32::MAX as u64,
+         ptx_parser::ScalarType::U64 => u64::MAX,
+         _ => return Err(error_unreachable()),
+     };
+     let upper_bound = unsafe { LLVMConstInt(src_type, upper_limit, 0) };
+     let umin_name = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
+     let saturated = self.emit_intrinsic(
+         unsafe { CStr::from_bytes_with_nul_unchecked(umin_name.as_bytes()) },
+         None,
+         &from.into(),
+         vec![(non_negative, src_type), (upper_bound, src_type)],
+     )?;
+
+     // Step 3: bring the clamped value to the width of the destination.
+     let is_narrowing = to.layout().size() < from.layout().size();
+     let convert = if is_narrowing {
+         LLVMBuildTrunc
+     } else {
+         LLVMBuildZExtOrBitCast
+     };
+     let dst_type = get_scalar_type(self.context, to);
+     self.resolver.with_result(arguments.dst, |dst| unsafe {
+         convert(self.builder, saturated, dst_type, dst)
+     });
+     Ok(())
+ }
+
fn emit_cvt_float_to_int(
&mut self,
from: ast::ScalarType,
to: ast::ScalarType,
rounding: ast::RoundingMode,
arguments: ptx_parser::CvtArgs<SpirvWord>,
- llvm_cast: &str,
+ llvm_cast: Option<
+ unsafe extern "C" fn(
+ arg1: LLVMBuilderRef,
+ Val: LLVMValueRef,
+ DestTy: LLVMTypeRef,
+ Name: *const i8,
+ ) -> LLVMValueRef,
+ >,
) -> Result<(), TranslateError> {
let prefix = match rounding {
ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
@@ -1602,6 +1706,20 @@ impl<'a> MethodEmitContext<'a> { get_scalar_type(self.context, from),
)],
)?;
+ if let Some(llvm_cast) = llvm_cast {
+ let to = get_scalar_type(self.context, to);
+ let poisoned_dst =
+ unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFreeze(self.builder, poisoned_dst, dst)
+ });
+ } else {
+ self.resolver.register(arguments.dst, rounded_float);
+ }
+ // Using explicit saturation gives us worse codegen: it explicitly checks for out-of-bounds
+ // values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM>, which
+ // saturates by default, and we don't care about NaNs anyway.
+ /*
let cast_intrinsic = format!(
"{}.{}.{}\0",
llvm_cast,
@@ -1614,6 +1732,7 @@ impl<'a> MethodEmitContext<'a> { &to.into(),
vec![(rounded_float, get_scalar_type(self.context, from))],
)?;
+ */
Ok(())
}
|