aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-10-15 04:42:44 +0200
committerAndrzej Janik <[email protected]>2024-10-15 04:42:44 +0200
commit6f2944d9be9e5cdf54fe4f52c948161cc10eb94d (patch)
tree31293f91f437a239f70bb4b3bcaff5a41f5f0e99
parentae42eac925201578d74ed3f49e380d07f6f7d0ed (diff)
downloadZLUDA-6f2944d9be9e5cdf54fe4f52c948161cc10eb94d.tar.gz
ZLUDA-6f2944d9be9e5cdf54fe4f52c948161cc10eb94d.zip
Add or, mad, fma, min, max, selp, lg2, ex2, popc, rem
-rw-r--r--ptx/src/pass/emit_llvm.rs309
-rw-r--r--ptx/src/pass/mod.rs10
2 files changed, 298 insertions, 21 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 1784745..209840f 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -518,7 +518,7 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
- ast::Instruction::Or { .. } => todo!(),
+ ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
@@ -528,15 +528,15 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(),
- ast::Instruction::Mad { .. } => todo!(),
- ast::Instruction::Fma { .. } => todo!(),
+ ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
+ ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
- ast::Instruction::Min { .. } => todo!(),
- ast::Instruction::Max { .. } => todo!(),
+ ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
+ ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
- ast::Instruction::Selp { .. } => todo!(),
+ ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
@@ -544,13 +544,13 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
- ast::Instruction::Lg2 { .. } => todo!(),
- ast::Instruction::Ex2 { .. } => todo!(),
+ ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
+ ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
- ast::Instruction::Popc { .. } => todo!(),
+ ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
- ast::Instruction::Rem { .. } => todo!(),
+ ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
ast::Instruction::PrmtSlow { .. } => todo!(),
ast::Instruction::Prmt { .. } => todo!(),
ast::Instruction::Membar { .. } => todo!(),
@@ -664,7 +664,14 @@ impl<'a> MethodEmitContext<'a> {
_ => todo!(),
}
}
- ConversionKind::SignExtend => todo!(),
+ ConversionKind::SignExtend => {
+ let src = self.resolver.value(conversion.src)?;
+ let type_ = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildSExt(builder, src, type_, dst)
+ });
+ Ok(())
+ }
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
@@ -986,20 +993,82 @@ impl<'a> MethodEmitContext<'a> {
data: ast::MulDetails,
arguments: ast::MulArgs<SpirvWord>,
) -> Result<(), TranslateError> {
+ self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
+ Ok(())
+ }
+
+ fn emit_mul_impl(
+ &mut self,
+ data: ast::MulDetails,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
let mul_fn = match data {
- ast::MulDetails::Integer { control, .. } => match control {
+ ast::MulDetails::Integer { control, type_ } => match control {
ast::MulIntControl::Low => LLVMBuildMul,
- ast::MulIntControl::High => todo!(),
- ast::MulIntControl::Wide => todo!(),
+ ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
+ ast::MulIntControl::Wide => {
+ return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
+ }
},
ast::MulDetails::Float(..) => LLVMBuildFMul,
};
- let src1 = self.resolver.value(arguments.src1)?;
- let src2 = self.resolver.value(arguments.src2)?;
- self.resolver.with_result(arguments.dst, |dst| unsafe {
- mul_fn(self.builder, src1, src2, dst)
- });
- Ok(())
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ Ok(self
+ .resolver
+ .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
+ }
+
+ fn emit_mul_high(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
+ let shift_constant =
+ unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
+ let shifted = unsafe {
+ LLVMBuildLShr(
+ self.builder,
+ wide_value,
+ shift_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let narrow_type = get_scalar_type(self.context, type_);
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
+ }))
+ }
+
+ fn emit_mul_wide_impl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ let wide_type =
+ unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
+ let llvm_cast = match type_.kind() {
+ ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
+ let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
+ Ok((
+ wide_type,
+ self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildMul(self.builder, src1, src2, dst)
+ }),
+ ))
}
fn emit_cos(
@@ -1018,6 +1087,19 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_or(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::OrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildOr(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
fn emit_xor(
&mut self,
_data: ptx_parser::ScalarType,
@@ -1612,6 +1694,191 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_ex2(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::Ex2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
+ ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(arguments.src, get_scalar_type(self.context, data.type_))],
+ )?;
+ Ok(())
+ }
+
+ fn emit_lg2(
+ &mut self,
+ _data: ptx_parser::FlushToZero,
+ arguments: ptx_parser::Lg2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_intrinsic(
+ c"llvm.amdgcn.log.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(
+ arguments.src,
+ get_scalar_type(self.context, ast::ScalarType::F32.into()),
+ )],
+ )?;
+ Ok(())
+ }
+
+ fn emit_selp(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::SelpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ LLVMBuildSelect(self.builder, src3, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_rem(
+ &mut self,
+ data: ptx_parser::ScalarType,
+ arguments: ptx_parser::RemArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.kind() {
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
+ ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ llvm_fn(self.builder, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_popc(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::PopcArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match type_ {
+ ast::ScalarType::B32 => c"llvm.ctpop.i32",
+ ast::ScalarType::B64 => c"llvm.ctpop.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &type_.into(),
+ vec![(arguments.src, llvm_type)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_min(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MinArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_max(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MaxArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, ScalarTypeInLLVM(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![(arguments.src1, llvm_type), (arguments.src2, llvm_type)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_fma(
+ &mut self,
+ data: ptx_parser::ArithFloat,
+ arguments: ptx_parser::FmaArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = format!("llvm.fma.{}\0", ScalarTypeInLLVM(data.type_));
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![
+ (arguments.src1, get_scalar_type(self.context, data.type_)),
+ (arguments.src2, get_scalar_type(self.context, data.type_)),
+ (arguments.src3, get_scalar_type(self.context, data.type_)),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_mad(
+ &mut self,
+ data: ptx_parser::MadDetails,
+ arguments: ptx_parser::MadArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let mul_control = match data {
+ ptx_parser::MadDetails::Float(mad_float) => {
+ return self.emit_fma(
+ mad_float,
+ ast::FmaArgs {
+ dst: arguments.dst,
+ src1: arguments.src1,
+ src2: arguments.src2,
+ src3: arguments.src3,
+ },
+ )
+ }
+ ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
+ ptx_parser::MadDetails::Integer { type_, control, .. } => {
+ ast::MulDetails::Integer { control, type_ }
+ }
+ };
+ let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAdd(self.builder, temp, src3, dst)
+ });
+ Ok(())
+ }
+
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@@ -1870,7 +2137,6 @@ impl ResolveIdent {
}
}
-/*
struct ScalarTypeInLLVM(ast::ScalarType);
impl std::fmt::Display for ScalarTypeInLLVM {
@@ -1893,6 +2159,7 @@ impl std::fmt::Display for ScalarTypeInLLVM {
}
}
+/*
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
match this {
ptx_parser::RoundingMode::Zero => 0,
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index 65292eb..ef131b4 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -150,6 +150,16 @@ fn error_unreachable() -> TranslateError {
}
#[cfg(debug_assertions)]
+fn error_todo() -> TranslateError {
+ unreachable!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_todo() -> TranslateError {
+ TranslateError::Todo
+}
+
+#[cfg(debug_assertions)]
fn error_unknown_symbol() -> TranslateError {
panic!()
}