diff options
author | Andrzej Janik <[email protected]> | 2024-10-15 19:16:11 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-10-15 19:16:11 +0200 |
commit | 002a19354aa739a489cec82646dcc2def6893e0d (patch) | |
tree | 60ff643f4c164c25fda535170e47b14fa3c44a27 | |
parent | 3105674618b790214ab629bd28162bcc27d8827a (diff) | |
download | ZLUDA-002a19354aa739a489cec82646dcc2def6893e0d.tar.gz ZLUDA-002a19354aa739a489cec82646dcc2def6893e0d.zip |
Add float-to-int cvt
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 124 |
1 files changed, 97 insertions, 27 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 54a07aa..cc40410 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -1118,7 +1118,7 @@ impl<'a> MethodEmitContext<'a> { c"llvm.cos.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
- vec![(arguments.src, llvm_f32)],
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
@@ -1371,7 +1371,7 @@ impl<'a> MethodEmitContext<'a> { c"llvm.sin.f32",
Some(arguments.dst),
&ast::ScalarType::F32.into(),
- vec![(arguments.src, llvm_f32)],
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
)?;
unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
Ok(())
@@ -1382,7 +1382,7 @@ impl<'a> MethodEmitContext<'a> { name: &CStr,
dst: Option<SpirvWord>,
return_type: &ast::Type,
- arguments: Vec<(SpirvWord, LLVMTypeRef)>,
+ arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
) -> Result<LLVMValueRef, TranslateError> {
let fn_type = get_function_type(
self.context,
@@ -1393,10 +1393,7 @@ impl<'a> MethodEmitContext<'a> { if fn_ == ptr::null_mut() {
fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
}
- let mut arguments = arguments
- .iter()
- .map(|(arg, _)| self.resolver.value(*arg))
- .collect::<Result<Vec<_>, _>>()?;
+ let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
Ok(self.resolver.with_result_option(dst, |dst| unsafe {
LLVMBuildCall2(
self.builder,
@@ -1538,11 +1535,11 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
- ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
+ ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt,
ptx_parser::CvtMode::FPTruncate {
rounding,
flush_to_zero,
- } => todo!(),
+ } => LLVMBuildFPTrunc,
ptx_parser::CvtMode::FPRound {
integer_rounding,
flush_to_zero,
@@ -1550,11 +1547,27 @@ impl<'a> MethodEmitContext<'a> { ptx_parser::CvtMode::SignedFromFP {
rounding,
flush_to_zero,
- } => todo!(),
+ } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ "llvm.fptosi.sat",
+ )
+ }
ptx_parser::CvtMode::UnsignedFromFP {
rounding,
flush_to_zero,
- } => todo!(),
+ } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ "llvm.fptoui.sat",
+ )
+ }
ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
};
@@ -1565,6 +1578,45 @@ impl<'a> MethodEmitContext<'a> { Ok(())
}
+ fn emit_cvt_float_to_int(
+ &mut self,
+ from: ast::ScalarType,
+ to: ast::ScalarType,
+ rounding: ast::RoundingMode,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ llvm_cast: &str,
+ ) -> Result<(), TranslateError> {
+ let prefix = match rounding {
+ ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
+ ptx_parser::RoundingMode::Zero => "llvm.trunc",
+ ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
+ ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
+ };
+ let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
+ let rounded_float = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, from),
+ )],
+ )?;
+ let cast_intrinsic = format!(
+ "{}.{}.{}\0",
+ llvm_cast,
+ LLVMTypeDisplay(to),
+ LLVMTypeDisplay(from)
+ );
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &to.into(),
+ vec![(rounded_float, get_scalar_type(self.context, from))],
+ )?;
+ Ok(())
+ }
+
fn emit_rsqrt(
&mut self,
data: ptx_parser::TypeFtz,
@@ -1580,7 +1632,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic,
Some(arguments.dst),
&data.type_.into(),
- vec![(arguments.src, type_)],
+ vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@@ -1601,7 +1653,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic,
Some(arguments.dst),
&data.type_.into(),
- vec![(arguments.src, type_)],
+ vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@@ -1623,7 +1675,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic,
Some(arguments.dst),
&data.type_.into(),
- vec![(arguments.src, type_)],
+ vec![(self.resolver.value(arguments.src)?, type_)],
)?;
Ok(())
}
@@ -1745,7 +1797,10 @@ impl<'a> MethodEmitContext<'a> { intrinsic,
Some(arguments.dst),
&data.type_.into(),
- vec![(arguments.src, get_scalar_type(self.context, data.type_))],
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, data.type_),
+ )],
)?;
Ok(())
}
@@ -1760,7 +1815,7 @@ impl<'a> MethodEmitContext<'a> { Some(arguments.dst),
&ast::ScalarType::F32.into(),
vec![(
- arguments.src,
+ self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32.into()),
)],
)?;
@@ -1814,7 +1869,7 @@ impl<'a> MethodEmitContext<'a> { intrinsic,
Some(arguments.dst),
&type_.into(),
- vec![(arguments.src, llvm_type)],
+ vec![(self.resolver.value(arguments.src)?, llvm_type)],
)?;
Ok(())
}
@@ -1832,13 +1887,16 @@ impl<'a> MethodEmitContext<'a> { }
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
};
- let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
- vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
)?;
Ok(())
}
@@ -1856,13 +1914,16 @@ impl<'a> MethodEmitContext<'a> { }
ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
};
- let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
let llvm_type = get_scalar_type(self.context, data.type_());
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_().into(),
- vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
)?;
Ok(())
}
@@ -1872,15 +1933,24 @@ impl<'a> MethodEmitContext<'a> { data: ptx_parser::ArithFloat,
arguments: ptx_parser::FmaArgs<SpirvWord>,
) -> Result<(), TranslateError> {
- let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
+ let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
self.emit_intrinsic(
unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
Some(arguments.dst),
&data.type_.into(),
vec![
- (arguments.src1, get_scalar_type(self.context, data.type_)),
- (arguments.src2, get_scalar_type(self.context, data.type_)),
- (arguments.src3, get_scalar_type(self.context, data.type_)),
+ (
+ self.resolver.value(arguments.src1)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src2)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src3)?,
+ get_scalar_type(self.context, data.type_),
+ ),
],
)?;
Ok(())
@@ -2238,9 +2308,9 @@ impl ResolveIdent { }
}
-struct ScalarTypeInLLVM(ast::ScalarType);
+struct LLVMTypeDisplay(ast::ScalarType);
-impl std::fmt::Display for ScalarTypeInLLVM {
+impl std::fmt::Display for LLVMTypeDisplay {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0 {
ast::ScalarType::Pred => write!(f, "i1"),
|