From 73eb31fec5f0dc73c5f43f99d21c757f3acb26cc Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 16 Oct 2024 03:12:54 +0200 Subject: Add saturated integer conversions --- ptx/src/pass/emit_llvm.rs | 131 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 125 insertions(+), 6 deletions(-) diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index cc40410..d9d2745 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -1533,8 +1533,12 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, - ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(), - ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(), + ptx_parser::CvtMode::SaturateUnsignedToSigned => { + return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) + } + ptx_parser::CvtMode::SaturateSignedToUnsigned => { + return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) + } ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt, ptx_parser::CvtMode::FPTruncate { rounding, @@ -1543,7 +1547,15 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::FPRound { integer_rounding, flush_to_zero, - } => todo!(), + } => { + return self.emit_cvt_float_to_int( + data.from, + data.to, + integer_rounding.unwrap_or(ast::RoundingMode::NearestEven), + arguments, + Some(LLVMBuildFPToSI), + ) + } ptx_parser::CvtMode::SignedFromFP { rounding, flush_to_zero, @@ -1553,7 +1565,7 @@ impl<'a> MethodEmitContext<'a> { data.to, rounding, arguments, - "llvm.fptosi.sat", + Some(LLVMBuildFPToSI), ) } ptx_parser::CvtMode::UnsignedFromFP { @@ -1565,7 +1577,7 @@ impl<'a> MethodEmitContext<'a> { data.to, rounding, arguments, - "llvm.fptoui.sat", + Some(LLVMBuildFPToUI), ) } ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(), @@ -1578,13 +1590,105 @@ impl<'a> MethodEmitContext<'a> { Ok(()) } + fn 
emit_cvt_unsigned_to_signed_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + // This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1, + // so if it's downcast to a smaller type, it will be the maximum value + // of the smaller type + let max_value = match to { + ptx_parser::ScalarType::S8 => i8::MAX as u64, + ptx_parser::ScalarType::S16 => i16::MAX as u64, + ptx_parser::ScalarType::S32 => i32::MAX as u64, + ptx_parser::ScalarType::S64 => i64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let from_llvm = get_scalar_type(self.context, from); + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let clamped = self.emit_intrinsic( + c"llvm.umin", + None, + &from.into(), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (max, from_llvm), + ], + )?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildSExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, clamped, to_llvm, dst) + }); + Ok(()) + } + + fn emit_cvt_signed_to_unsigned_sat( + &mut self, + from: ptx_parser::ScalarType, + to: ptx_parser::ScalarType, + arguments: ptx_parser::CvtArgs, + ) -> Result<(), TranslateError> { + let from_llvm = get_scalar_type(self.context, from); + let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) }; + let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); + let zero_clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) }, + None, + &from.into(), + vec![ + (self.resolver.value(arguments.src)?, from_llvm), + (zero, from_llvm), + ], + )?; + // zero_clamped is now unsigned + let max_value = match to { + ptx_parser::ScalarType::U8 => u8::MAX as u64, + ptx_parser::ScalarType::U16 => u16::MAX as u64, + 
ptx_parser::ScalarType::U32 => u32::MAX as u64, + ptx_parser::ScalarType::U64 => u64::MAX as u64, + _ => return Err(error_unreachable()), + }; + let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; + let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); + let fully_clamped = self.emit_intrinsic( + unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) }, + None, + &from.into(), + vec![(zero_clamped, from_llvm), (max, from_llvm)], + )?; + let resize_fn = if to.layout().size() >= from.layout().size() { + LLVMBuildZExtOrBitCast + } else { + LLVMBuildTrunc + }; + let to_llvm = get_scalar_type(self.context, to); + self.resolver.with_result(arguments.dst, |dst| unsafe { + resize_fn(self.builder, fully_clamped, to_llvm, dst) + }); + Ok(()) + } + fn emit_cvt_float_to_int( &mut self, from: ast::ScalarType, to: ast::ScalarType, rounding: ast::RoundingMode, arguments: ptx_parser::CvtArgs, - llvm_cast: &str, + llvm_cast: Option< + unsafe extern "C" fn( + arg1: LLVMBuilderRef, + Val: LLVMValueRef, + DestTy: LLVMTypeRef, + Name: *const i8, + ) -> LLVMValueRef, + >, ) -> Result<(), TranslateError> { let prefix = match rounding { ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", @@ -1602,6 +1706,20 @@ impl<'a> MethodEmitContext<'a> { get_scalar_type(self.context, from), )], )?; + if let Some(llvm_cast) = llvm_cast { + let to = get_scalar_type(self.context, to); + let poisoned_dst = + unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) }; + self.resolver.with_result(arguments.dst, |dst| unsafe { + LLVMBuildFreeze(self.builder, poisoned_dst, dst) + }); + } else { + self.resolver.register(arguments.dst, rounded_float); + } + // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound + // values and NaNs. 
Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which + // saturates by default and we don't care about NaNs anyway + /* let cast_intrinsic = format!( "{}.{}.{}\0", llvm_cast, @@ -1614,6 +1732,7 @@ impl<'a> MethodEmitContext<'a> { &to.into(), vec![(rounded_float, get_scalar_type(self.context, from))], )?; + */ Ok(()) } -- cgit v1.2.3