aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2024-10-11 22:55:10 +0200
committer Andrzej Janik <[email protected]> 2024-10-11 22:55:10 +0200
commit d9c33ca50559a7c57bb8b6db750fc667adb557a9 (patch)
tree 610bfca9df9aa1736c006a39da8eb69d1b842675
parent c8b88f4483eaf5ee68cd9306ca57dfaa5f7d0ce0 (diff)
download ZLUDA-d9c33ca50559a7c57bb8b6db750fc667adb557a9.tar.gz
ZLUDA-d9c33ca50559a7c57bb8b6db750fc667adb557a9.zip
Add br, setp, not, cvta, sub, neg, sin
-rw-r--r-- ptx/src/pass/emit_llvm.rs 281
1 file changed, 259 insertions, 22 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 15177bc..ce1eb84 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -451,7 +451,7 @@ impl<'a> MethodEmitContext<'a> {
Statement::Variable(var) => self.emit_variable(var)?,
Statement::Label(label) => self.emit_label_delayed(label)?,
Statement::Instruction(inst) => self.emit_instruction(inst)?,
- Statement::Conditional(_) => todo!(),
+ Statement::Conditional(cond) => self.emit_conditional(cond)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
Statement::RetValue(_, values) => self.emit_ret_value(values)?,
@@ -515,9 +515,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
- ast::Instruction::Setp { .. } => todo!(),
+ ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
ast::Instruction::SetpBool { .. } => todo!(),
- ast::Instruction::Not { .. } => todo!(),
+ ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
ast::Instruction::Or { .. } => todo!(),
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
@@ -526,11 +526,11 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Shr { .. } => todo!(),
ast::Instruction::Shl { .. } => todo!(),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
- ast::Instruction::Cvta { .. } => todo!(),
+ ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(),
ast::Instruction::Mad { .. } => todo!(),
ast::Instruction::Fma { .. } => todo!(),
- ast::Instruction::Sub { .. } => todo!(),
+ ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { .. } => todo!(),
ast::Instruction::Max { .. } => todo!(),
ast::Instruction::Rcp { .. } => todo!(),
@@ -541,8 +541,8 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
- ast::Instruction::Neg { .. } => todo!(),
- ast::Instruction::Sin { .. } => todo!(),
+ ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
+ ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
ast::Instruction::Lg2 { .. } => todo!(),
ast::Instruction::Ex2 { .. } => todo!(),
@@ -651,6 +651,16 @@ impl<'a> MethodEmitContext<'a> {
}
}
}
+ (ast::Type::Vector(..), ast::Type::Scalar(..))
+ | (ast::Type::Scalar(..), ast::Type::Array(..))
+ | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildBitCast(builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
_ => todo!(),
}
}
@@ -997,20 +1007,13 @@ impl<'a> MethodEmitContext<'a> {
_data: ast::FlushToZero,
arguments: ast::CosArgs<SpirvWord>,
) -> Result<(), TranslateError> {
- let llvm_fn = c"llvm.cos.f32";
- let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
- let fn_type = get_function_type(
- self.context,
- iter::once(&ast::ScalarType::F32.into()),
- iter::once(Ok(get_scalar_type(self.context, ast::ScalarType::F32))),
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let cos = self.emit_intrinsic(
+ c"llvm.cos.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(arguments.src, llvm_f32)],
)?;
- if fn_ == ptr::null_mut() {
- fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
- }
- let mut src = self.resolver.value(arguments.src)?;
- let cos = self.resolver.with_result(arguments.dst, |dst| unsafe {
- LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
- });
unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
Ok(())
}
@@ -1168,6 +1171,241 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_cvta( // Emits `cvta` as an int-to-pointer conversion followed by an address-space cast.
+ &mut self,
+ data: ptx_parser::CvtaDetails, // read for `direction` and `state_space`
+ arguments: ptx_parser::CvtaArgs<SpirvWord>, // read for `src` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `get_pointer_type` and `resolver.value`
+ let (from_space, to_space) = match data.direction { // one end of the cast is always Generic
+ ptx_parser::CvtaDirection::GenericToExplicit => {
+ (ast::StateSpace::Generic, data.state_space)
+ }
+ ptx_parser::CvtaDirection::ExplicitToGeneric => {
+ (data.state_space, ast::StateSpace::Generic)
+ }
+ };
+ let from_type = get_pointer_type(self.context, from_space)?;
+ let dest_type = get_pointer_type(self.context, to_space)?;
+ let src = self.resolver.value(arguments.src)?;
+ let temp_ptr = // intermediate pointer in the source space, built under the LLVM_UNNAMED name
+ unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe { // the cast is the value handed to the resolver for `dst`
+ LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) // NOTE(review): result is pointer-typed; confirm users of `dst` expect a pointer
+ });
+ Ok(())
+ }
+
+ fn emit_sub( // Emits `sub` by dispatching to the integer or floating-point emitter.
+ &mut self,
+ data: ptx_parser::ArithDetails, // selects which emitter handles the instruction
+ arguments: ptx_parser::SubArgs<SpirvWord>, // forwarded unchanged to the selected emitter
+ ) -> Result<(), TranslateError> { // returns whatever the selected emitter returns
+ match data {
+ ptx_parser::ArithDetails::Integer(arith_integer) => {
+ self.emit_sub_integer(arith_integer, arguments)
+ }
+ ptx_parser::ArithDetails::Float(arith_float) => {
+ self.emit_sub_float(arith_float, arguments)
+ }
+ }
+ }
+
+ fn emit_sub_integer( // Emits an integer `sub` as a single LLVM `sub` instruction.
+ &mut self,
+ arith_integer: ptx_parser::ArithInteger, // only `saturate` is inspected here
+ arguments: ptx_parser::SubArgs<SpirvWord>, // read for `src1`, `src2` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ if arith_integer.saturate {
+ todo!() // saturating subtraction is not implemented yet; this panics
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe { // computes src1 - src2 for `dst`
+ LLVMBuildSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sub_float( // Emits a floating-point `sub` as a single LLVM `fsub` instruction.
+ &mut self,
+ arith_float: ptx_parser::ArithFloat, // only `saturate` is inspected here; nothing else in it affects the output
+ arguments: ptx_parser::SubArgs<SpirvWord>, // read for `src1`, `src2` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ if arith_float.saturate {
+ todo!() // saturating subtraction is not implemented yet; this panics
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe { // computes src1 - src2 for `dst`
+ LLVMBuildFSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sin( // Emits `sin` as a call to the `llvm.sin.f32` intrinsic, flagged as an approximate function.
+ &mut self,
+ _data: ptx_parser::FlushToZero, // unused: nothing in this function reads it
+ arguments: ptx_parser::SinArgs<SpirvWord>, // read for `src` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `emit_intrinsic`
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let sin = self.emit_intrinsic( // one f32 argument, f32 return, result bound to `arguments.dst`
+ c"llvm.sin.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(arguments.src, llvm_f32)],
+ )?;
+ unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } // applied to the call value returned above
+ Ok(())
+ }
+
+ fn emit_intrinsic( // Emits a call to the function called `name`, declaring it in the module on first use.
+ &mut self,
+ name: &CStr, // symbol looked up in (or added to) `self.module`
+ dst: Option<SpirvWord>, // optional id passed to `with_result_option` for the call
+ return_type: &ast::Type, // the single return type used to build the function type
+ arguments: Vec<(SpirvWord, LLVMTypeRef)>, // (id of the argument value, LLVM type of that parameter)
+ ) -> Result<LLVMValueRef, TranslateError> { // Ok carries the value produced by `with_result_option`
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(return_type),
+ arguments.iter().map(|(_, type_)| Ok(*type_)),
+ )?;
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+ if fn_ == ptr::null_mut() { // no function of that name yet: add it with `fn_type`
+ fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ } // an already-present function is reused as-is; its type is not compared with `fn_type`
+ let mut arguments = arguments // shadows the parameter with the resolved values; first failure aborts
+ .iter()
+ .map(|(arg, _)| self.resolver.value(*arg))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_type,
+ fn_,
+ arguments.as_mut_ptr(),
+ arguments.len() as u32, // NOTE(review): `as` would truncate silently if the count exceeded u32::MAX
+ dst,
+ )
+ }))
+ }
+
+ fn emit_neg( // Emits `neg` with the float or integer LLVM negate builder, chosen from the operand type.
+ &mut self,
+ data: ptx_parser::TypeFtz, // only `type_` is inspected here
+ arguments: ptx_parser::NegArgs<SpirvWord>, // read for `src` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ let src = self.resolver.value(arguments.src)?;
+ let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float { // every non-Float kind takes the integer path
+ LLVMBuildFNeg
+ } else {
+ LLVMBuildNeg
+ };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_not( // Emits `not` through a single `LLVMBuildNot` call.
+ &mut self,
+ _data: ptx_parser::ScalarType, // unused: the operand type is not consulted
+ arguments: ptx_parser::NotArgs<SpirvWord>, // read for `src` and `dst`
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildNot(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_setp( // Emits `setp` by dispatching on the comparison operator's integer/float family.
+ &mut self,
+ data: ptx_parser::SetpData, // only `cmp_op` is inspected here
+ arguments: ptx_parser::SetpArgs<SpirvWord>, // forwarded unchanged to the selected emitter
+ ) -> Result<(), TranslateError> { // returns whatever the selected emitter returns
+ if arguments.dst2.is_some() {
+ todo!() // the form with a second destination is not implemented yet; this panics
+ }
+ match data.cmp_op {
+ ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
+ self.emit_setp_int(setp_compare_int, arguments)
+ }
+ ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
+ self.emit_setp_float(setp_compare_float, arguments)
+ }
+ }
+ }
+
+ fn emit_setp_int( // Emits an integer `setp` as one LLVM `icmp` written to `dst1`.
+ &mut self,
+ setp: ptx_parser::SetpCompareInt, // comparison to translate into an `LLVMIntPredicate`
+ arguments: ptx_parser::SetpArgs<SpirvWord>, // read for `src1`, `src2` and `dst1`; `dst2` is not used
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ let op = match setp { // one-to-one mapping; signedness is carried by the variant itself
+ ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
+ ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
+ ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
+ ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
+ ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
+ ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
+ ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
+ ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
+ ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
+ ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildICmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_setp_float( // Emits a floating-point `setp` as one LLVM `fcmp` written to `dst1`.
+ &mut self,
+ setp: ptx_parser::SetpCompareFloat, // comparison to translate into an `LLVMRealPredicate`
+ arguments: ptx_parser::SetpArgs<SpirvWord>, // read for `src1`, `src2` and `dst1`; `dst2` is not used
+ ) -> Result<(), TranslateError> { // errors propagate from `resolver.value`
+ let op = match setp { // plain variants map to ordered (O*) predicates, `Nan*` variants to unordered (U*)
+ ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
+ ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
+ ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
+ ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
+ ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
+ ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
+ ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
+ ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
+ ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
+ ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
+ ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
+ ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
+ ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, // the "ordered" predicate
+ ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, // the "unordered" predicate
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { // Emits a conditional branch on `cond.predicate` to `cond.if_true` or `cond.if_false`.
+ let predicate = self.resolver.value(cond.predicate)?;
+ let if_true = self.resolver.value(cond.if_true)?; // both targets are resolved like ordinary values
+ let if_false = self.resolver.value(cond.if_false)?;
+ unsafe {
+ LLVMBuildCondBr(
+ self.builder,
+ predicate,
+ LLVMValueAsBasicBlock(if_true), // NOTE(review): assumes the resolved value really is a basic block; confirm how labels are registered
+ LLVMValueAsBasicBlock(if_false),
+ )
+ }; // the built branch instruction is not recorded anywhere
+ Ok(())
+ }
+
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
@@ -1328,8 +1566,7 @@ fn get_function_type<'a>(
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
- let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
- input_args.collect::<Result<Vec<_>, _>>()?;
+ let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,