aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-02-15 15:07:05 +0000
committerAndrzej Janik <[email protected]>2024-02-15 15:07:05 +0000
commit8fef0e4fe7e6ee6d1f2b556694faffee0c6b9648 (patch)
tree3ee01dcf40d9fa2dbcb6110c5685d5fb8fcd7be9
parent8d10f756a949b9759cfdd90d7bcc9a9381bce56b (diff)
downloadZLUDA-8fef0e4fe7e6ee6d1f2b556694faffee0c6b9648.tar.gz
ZLUDA-8fef0e4fe7e6ee6d1f2b556694faffee0c6b9648.zip
Support sign extending in prmt
-rw-r--r--ptx/src/emit.rs36
-rw-r--r--ptx/src/ptx.lalrpop9
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/test/spirv_run/prmt.ll76
-rw-r--r--ptx/src/test/spirv_run/prmt.ptx8
5 files changed, 91 insertions, 40 deletions
diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs
index 89fce11..ac054d4 100644
--- a/ptx/src/emit.rs
+++ b/ptx/src/emit.rs
@@ -1654,14 +1654,17 @@ fn emit_inst_prmt(
) -> Result<(), TranslateError> {
let builder = ctx.builder.get();
let components = [
- ((control >> 0) & 0b1111) as u32,
- ((control >> 4) & 0b1111) as u32,
- ((control >> 8) & 0b1111) as u32,
- ((control >> 12) & 0b1111) as u32,
+ ((control >> 0) & 0b0111) as u32,
+ ((control >> 4) & 0b0111) as u32,
+ ((control >> 8) & 0b0111) as u32,
+ ((control >> 12) & 0b0111) as u32,
+ ];
+ let sext_components = [
+ ((control >> 0) & 0b1000) != 0,
+ ((control >> 4) & 0b1000) != 0,
+ ((control >> 8) & 0b1000) != 0,
+ ((control >> 12) & 0b1000) != 0,
];
- if components.iter().any(|&c| c > 7) {
- return Err(TranslateError::todo());
- }
let llvm_i32 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
let llvm_vec4_i8 = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::U8, 4))?;
let src1 = ctx.names.value(arg.src1)?;
@@ -1674,9 +1677,24 @@ fn emit_inst_prmt(
unsafe { LLVMConstInt(llvm_i32, components[2] as _, 0) },
unsafe { LLVMConstInt(llvm_i32, components[3] as _, 0) },
];
- let mask = unsafe { LLVMConstVector(components_llvm.as_mut_ptr(), 4) };
- let shuffle_result =
+ let mask =
+ unsafe { LLVMConstVector(components_llvm.as_mut_ptr(), components_llvm.len() as u32) };
+ let mut shuffle_result =
unsafe { LLVMBuildShuffleVector(builder, src1_vector, src2_vector, mask, LLVM_UNNAMED) };
+ // In sext case I'd prefer to just emit V_PERM_B32 directly and be done with it,
+ // but V_PERM_B32 can sext only odd-indexed bytes.
+ let llvm_i8 = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U8))?;
+ let const_7 = unsafe { LLVMConstInt(llvm_i8, 7, 0) };
+ for (idx, requires_sext) in sext_components.iter().copied().enumerate() {
+ if !requires_sext {
+ continue;
+ }
+ let idx = unsafe { LLVMConstInt(llvm_i32, idx as u64, 0) };
+ let scalar = unsafe { LLVMBuildExtractElement(builder, shuffle_result, idx, LLVM_UNNAMED) };
+ let shift = unsafe { LLVMBuildAShr(builder, scalar, const_7, LLVM_UNNAMED) };
+ shuffle_result =
+ unsafe { LLVMBuildInsertElement(builder, shuffle_result, shift, idx, LLVM_UNNAMED) };
+ }
ctx.names.register_result(arg.dst, |dst_name| unsafe {
LLVMBuildBitCast(builder, shuffle_result, llvm_i32, dst_name)
});
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 88c3699..ae57575 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1097,6 +1097,15 @@ InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-set
InstSet: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "set" <cmp_op:SetpCompareOp> <ftz:".ftz"?> ".f16x2" ".f16x2" <arg:Arg3> => {
+ let data = ast::SetData {
+ dst_type: ast::ScalarType::F16x2,
+ src_type: ast::ScalarType::F16x2,
+ flush_to_zero: ftz.is_some(),
+ cmp_op: cmp_op,
+ };
+ ast::Instruction::Set(data, arg)
+ },
"set" <cmp_op:SetpCompareOp> <ftz:".ftz"?> ".u32" ".f16x2" <arg:Arg3> => {
let data = ast::SetData {
dst_type: ast::ScalarType::U32,
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index bd745fd..e640765 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -271,7 +271,7 @@ test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(cvt_f32_f16, [0xa1u16], [0x37210000u32]);
-test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
+test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32, 0x6FFFD600]);
test_ptx!(
prmt_non_immediate,
[0x70c507d6u32, 0x6fbd4b5cu32],
diff --git a/ptx/src/test/spirv_run/prmt.ll b/ptx/src/test/spirv_run/prmt.ll
index a901ce4..87313c6 100644
--- a/ptx/src/test/spirv_run/prmt.ll
+++ b/ptx/src/test/spirv_run/prmt.ll
@@ -1,40 +1,60 @@
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"
-define protected amdgpu_kernel void @prmt(ptr addrspace(4) byref(i64) %"23", ptr addrspace(4) byref(i64) %"24") #0 {
-"31":
- %"8" = alloca i1, align 1, addrspace(5)
- store i1 false, ptr addrspace(5) %"8", align 1
- %"9" = alloca i1, align 1, addrspace(5)
- store i1 false, ptr addrspace(5) %"9", align 1
+define protected amdgpu_kernel void @prmt(ptr addrspace(4) byref(i64) %"32", ptr addrspace(4) byref(i64) %"33") #0 {
+"44":
+ %"10" = alloca i1, align 1, addrspace(5)
+ store i1 false, ptr addrspace(5) %"10", align 1
+ %"11" = alloca i1, align 1, addrspace(5)
+ store i1 false, ptr addrspace(5) %"11", align 1
%"4" = alloca i64, align 8, addrspace(5)
%"5" = alloca i64, align 8, addrspace(5)
%"6" = alloca i32, align 4, addrspace(5)
%"7" = alloca i32, align 4, addrspace(5)
- %"10" = load i64, ptr addrspace(4) %"23", align 8
- store i64 %"10", ptr addrspace(5) %"4", align 8
- %"11" = load i64, ptr addrspace(4) %"24", align 8
- store i64 %"11", ptr addrspace(5) %"5", align 8
- %"13" = load i64, ptr addrspace(5) %"4", align 8
- %"25" = inttoptr i64 %"13" to ptr
- %"12" = load i32, ptr %"25", align 4
- store i32 %"12", ptr addrspace(5) %"6", align 4
+ %"8" = alloca i32, align 4, addrspace(5)
+ %"9" = alloca i32, align 4, addrspace(5)
+ %"12" = load i64, ptr addrspace(4) %"32", align 8
+ store i64 %"12", ptr addrspace(5) %"4", align 8
+ %"13" = load i64, ptr addrspace(4) %"33", align 8
+ store i64 %"13", ptr addrspace(5) %"5", align 8
%"15" = load i64, ptr addrspace(5) %"4", align 8
- %"26" = inttoptr i64 %"15" to ptr
- %"33" = getelementptr inbounds i8, ptr %"26", i64 4
- %"14" = load i32, ptr %"33", align 4
- store i32 %"14", ptr addrspace(5) %"7", align 4
- %"17" = load i32, ptr addrspace(5) %"6", align 4
- %"18" = load i32, ptr addrspace(5) %"7", align 4
- %0 = bitcast i32 %"17" to <4 x i8>
- %1 = bitcast i32 %"18" to <4 x i8>
- %2 = shufflevector <4 x i8> %0, <4 x i8> %1, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
- %"27" = bitcast <4 x i8> %2 to i32
- store i32 %"27", ptr addrspace(5) %"7", align 4
- %"19" = load i64, ptr addrspace(5) %"5", align 8
+ %"34" = inttoptr i64 %"15" to ptr
+ %"14" = load i32, ptr %"34", align 4
+ store i32 %"14", ptr addrspace(5) %"6", align 4
+ %"17" = load i64, ptr addrspace(5) %"4", align 8
+ %"35" = inttoptr i64 %"17" to ptr
+ %"46" = getelementptr inbounds i8, ptr %"35", i64 4
+ %"16" = load i32, ptr %"46", align 4
+ store i32 %"16", ptr addrspace(5) %"7", align 4
+ %"19" = load i32, ptr addrspace(5) %"6", align 4
%"20" = load i32, ptr addrspace(5) %"7", align 4
- %"30" = inttoptr i64 %"19" to ptr
- store i32 %"20", ptr %"30", align 4
+ %0 = bitcast i32 %"19" to <4 x i8>
+ %1 = bitcast i32 %"20" to <4 x i8>
+ %2 = shufflevector <4 x i8> %0, <4 x i8> %1, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
+ %"36" = bitcast <4 x i8> %2 to i32
+ store i32 %"36", ptr addrspace(5) %"8", align 4
+ %"22" = load i32, ptr addrspace(5) %"6", align 4
+ %"23" = load i32, ptr addrspace(5) %"7", align 4
+ %3 = bitcast i32 %"22" to <4 x i8>
+ %4 = bitcast i32 %"23" to <4 x i8>
+ %5 = shufflevector <4 x i8> %3, <4 x i8> %4, <4 x i32> <i32 4, i32 0, i32 6, i32 7>
+ %6 = extractelement <4 x i8> %5, i32 0
+ %7 = ashr i8 %6, 7
+ %8 = insertelement <4 x i8> %5, i8 %7, i32 0
+ %9 = extractelement <4 x i8> %8, i32 2
+ %10 = ashr i8 %9, 7
+ %11 = insertelement <4 x i8> %8, i8 %10, i32 2
+ %"39" = bitcast <4 x i8> %11 to i32
+ store i32 %"39", ptr addrspace(5) %"9", align 4
+ %"24" = load i64, ptr addrspace(5) %"5", align 8
+ %"25" = load i32, ptr addrspace(5) %"8", align 4
+ %"42" = inttoptr i64 %"24" to ptr
+ store i32 %"25", ptr %"42", align 4
+ %"26" = load i64, ptr addrspace(5) %"5", align 8
+ %"27" = load i32, ptr addrspace(5) %"9", align 4
+ %"43" = inttoptr i64 %"26" to ptr
+ %"48" = getelementptr inbounds i8, ptr %"43", i64 4
+ store i32 %"27", ptr %"48", align 4
ret void
}
diff --git a/ptx/src/test/spirv_run/prmt.ptx b/ptx/src/test/spirv_run/prmt.ptx
index ba339e8..9901177 100644
--- a/ptx/src/test/spirv_run/prmt.ptx
+++ b/ptx/src/test/spirv_run/prmt.ptx
@@ -11,13 +11,17 @@
.reg .u64 out_addr;
.reg .u32 temp1;
.reg .u32 temp2;
+ .reg .u32 temp3;
+ .reg .u32 temp4;
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
ld.u32 temp1, [in_addr];
ld.u32 temp2, [in_addr+4];
- prmt.b32 temp2, temp1, temp2, 30212;
- st.u32 [out_addr], temp2;
+ prmt.b32 temp3, temp1, temp2, 30212;
+ prmt.b32 temp4, temp1, temp2, 32268;
+ st.u32 [out_addr], temp3;
+ st.u32 [out_addr+4], temp4;
ret;
}