aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-10-14 19:09:47 +0200
committerAndrzej Janik <[email protected]>2024-10-14 19:09:47 +0200
commitae42eac925201578d74ed3f49e380d07f6f7d0ed (patch)
tree17bee8312d2ff5de66df061e05575986b2bf5216 /ptx
parentd9c33ca50559a7c57bb8b6db750fc667adb557a9 (diff)
downloadZLUDA-ae42eac925201578d74ed3f49e380d07f6f7d0ed.tar.gz
ZLUDA-ae42eac925201578d74ed3f49e380d07f6f7d0ed.zip
Add shifts, cvt, rsqrt, sqrt, rcp, more sregs
Diffstat (limited to 'ptx')
-rw-r--r--ptx/lib/zluda_ptx_impl.bcbin4136 -> 4624 bytes
-rw-r--r--ptx/lib/zluda_ptx_impl.cpp18
-rw-r--r--ptx/src/pass/emit_llvm.rs218
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/shl_link_hack.ptx30
5 files changed, 230 insertions, 37 deletions
diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index 9533233..6651430 100644
--- a/ptx/lib/zluda_ptx_impl.bc
+++ b/ptx/lib/zluda_ptx_impl.bc
Binary files differ
diff --git a/ptx/lib/zluda_ptx_impl.cpp b/ptx/lib/zluda_ptx_impl.cpp
index 85823b4..f1b416d 100644
--- a/ptx/lib/zluda_ptx_impl.cpp
+++ b/ptx/lib/zluda_ptx_impl.cpp
@@ -13,12 +13,30 @@ extern "C"
return __builtin_amdgcn_read_exec_lo();
}
+ size_t __ockl_get_local_id(uint32_t) __device__;
+ uint32_t FUNC(sreg_tid)(uint8_t member)
+ {
+ return (uint32_t)__ockl_get_local_id(member);
+ }
+
size_t __ockl_get_local_size(uint32_t) __device__;
uint32_t FUNC(sreg_ntid)(uint8_t member)
{
return (uint32_t)__ockl_get_local_size(member);
}
+ size_t __ockl_get_global_id(uint32_t) __device__;
+ uint32_t FUNC(sreg_ctaid)(uint8_t member)
+ {
+ return (uint32_t)__ockl_get_global_id(member);
+ }
+
+ size_t __ockl_get_global_size(uint32_t) __device__;
+ uint32_t FUNC(sreg_nctaid)(uint8_t member)
+ {
+ return (uint32_t)__ockl_get_global_size(member);
+ }
+
uint32_t __ockl_bfe_u32(uint32_t, uint32_t, uint32_t) __attribute__((device));
uint32_t FUNC(bfe_u32)(uint32_t base, uint32_t pos_32, uint32_t len_32)
{
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index ce1eb84..1784745 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -522,9 +522,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
- ast::Instruction::Cvt { .. } => todo!(),
- ast::Instruction::Shr { .. } => todo!(),
- ast::Instruction::Shl { .. } => todo!(),
+ ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
+ ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
+ ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
ast::Instruction::Abs { .. } => todo!(),
@@ -533,9 +533,9 @@ impl<'a> MethodEmitContext<'a> {
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
ast::Instruction::Min { .. } => todo!(),
ast::Instruction::Max { .. } => todo!(),
- ast::Instruction::Rcp { .. } => todo!(),
- ast::Instruction::Sqrt { .. } => todo!(),
- ast::Instruction::Rsqrt { .. } => todo!(),
+ ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
+ ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
+ ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
ast::Instruction::Selp { .. } => todo!(),
ast::Instruction::Bar { .. } => todo!(),
ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
@@ -1406,6 +1406,212 @@ impl<'a> MethodEmitContext<'a> {
Ok(())
}
+ fn emit_cvt(
+ &mut self,
+ data: ptx_parser::CvtDetails,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let dst_type = get_scalar_type(self.context, data.to);
+ let llvm_fn = match data.mode {
+ ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
+ ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
+ ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
+ ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
+ ptx_parser::CvtMode::SaturateUnsignedToSigned => todo!(),
+ ptx_parser::CvtMode::SaturateSignedToUnsigned => todo!(),
+ ptx_parser::CvtMode::FPExtend { flush_to_zero } => todo!(),
+ ptx_parser::CvtMode::FPTruncate {
+ rounding,
+ flush_to_zero,
+ } => todo!(),
+ ptx_parser::CvtMode::FPRound {
+ integer_rounding,
+ flush_to_zero,
+ } => todo!(),
+ ptx_parser::CvtMode::SignedFromFP {
+ rounding,
+ flush_to_zero,
+ } => todo!(),
+ ptx_parser::CvtMode::UnsignedFromFP {
+ rounding,
+ flush_to_zero,
+ } => todo!(),
+ ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(),
+ ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(),
+ };
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_rsqrt(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::RsqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
+ ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(arguments.src, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_sqrt(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::SqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
+ (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
+ (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(arguments.src, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
+ (_, ast::RcpKind::Compliant(rnd)) => {
+ return self.emit_rcp_compliant(data, arguments, rnd)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(arguments.src, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp_compliant(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ _rnd: ast::RoundingMode,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let one = unsafe { LLVMConstReal(type_, 1.0) };
+ let src = self.resolver.value(arguments.src)?;
+ let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFDiv(self.builder, one, src, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) };
+ Ok(())
+ }
+
+ fn emit_shr(
+ &mut self,
+ data: ptx_parser::ShrData,
+ arguments: ptx_parser::ShrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let shift_fn = match data.kind {
+ ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
+ ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
+ };
+ self.emit_shift(
+ data.type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ shift_fn,
+ )
+ }
+
+ fn emit_shl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::ShlArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_shift(
+ type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ LLVMBuildShl,
+ )
+ }
+
+ fn emit_shift(
+ &mut self,
+ type_: ast::ScalarType,
+ dst: SpirvWord,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ llvm_fn: unsafe extern "C" fn(
+ LLVMBuilderRef,
+ LLVMValueRef,
+ LLVMValueRef,
+ *const i8,
+ ) -> LLVMValueRef,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let shift_size = self.resolver.value(src2)?;
+ let integer_bits = type_.layout().size() * 8;
+ let integer_bits_constant = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::U32),
+ integer_bits as u64,
+ 0,
+ )
+ };
+ let should_clamp = unsafe {
+ LLVMBuildICmp(
+ self.builder,
+ LLVMIntPredicate::LLVMIntUGE,
+ shift_size,
+ integer_bits_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ let zero = unsafe { LLVMConstNull(llvm_type) };
+ let normalized_shift_size = if type_.layout().size() >= 4 {
+ unsafe {
+ LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
+ }
+ } else {
+ unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
+ };
+ let shifted = unsafe {
+ llvm_fn(
+ self.builder,
+ src1,
+ normalized_shift_size,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
+ });
+ Ok(())
+ }
+
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 1b5afee..f4b7921 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -45,7 +45,6 @@ test_ptx!(setp_leu, [1f32, f32::NAN], [1f32]);
test_ptx!(bra, [10u64], [11u64]);
test_ptx!(not, [0u64], [u64::max_value()]);
test_ptx!(shl, [11u64], [44u64]);
-test_ptx!(shl_link_hack, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
diff --git a/ptx/src/test/spirv_run/shl_link_hack.ptx b/ptx/src/test/spirv_run/shl_link_hack.ptx
deleted file mode 100644
index a32555c..0000000
--- a/ptx/src/test/spirv_run/shl_link_hack.ptx
+++ /dev/null
@@ -1,30 +0,0 @@
-// HACK ALERT
-// This test is for testing workaround for a bug in IGC where linking fails
-// if there is shl/shr with different width of value and shift
-
-.version 6.5
-.target sm_30
-.address_size 64
-
-.visible .entry shl_link_hack(
- .param .u64 input,
- .param .u64 output
-)
-{
- .reg .u64 in_addr;
- .reg .u64 out_addr;
- .reg .u64 temp;
- .reg .u64 temp2;
-
- ld.param.u64 in_addr, [input];
- ld.param.u64 out_addr, [output];
-
- // Here only to trigger linking
- .reg .u32 unused;
- atom.inc.u32 unused, [out_addr], 2000000;
-
- ld.u64 temp, [in_addr];
- shl.b64 temp2, temp, 2;
- st.u64 [out_addr], temp2;
- ret;
-}